# Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch GPT-J model."""

import math

import torch
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ... import initialization as init
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...masking_utils import create_causal_mask
from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutputWithPast,
)
from ...modeling_utils import PreTrainedModel
from ...utils import auto_docstring, logging
from .configuration_gptj import GPTJConfig


if is_flash_attn_available():
    from ...modeling_flash_attention_utils import _flash_attention_forward


logger = logging.get_logger(__name__)


def create_sinusoidal_positions(num_pos: int, dim: int) -> torch.Tensor:
    """
    Build the fixed sin/cos lookup table used for rotary position embeddings.

    Args:
        num_pos: number of positions (rows) in the table.
        dim: rotary width; one inverse frequency is produced for every even index in `[0, dim)`.

    Returns:
        A float tensor of shape `(num_pos, 2 * ceil(dim / 2))` -- i.e. `(num_pos, dim)` for even `dim` -- whose
        first half of columns is `sin(pos * inv_freq)` and second half is `cos(pos * inv_freq)`.
    """
    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64) / dim))
    sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(num_pos, dtype=torch.int64).float(), inv_freq).float()
    return torch.cat((torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)), dim=1)


def get_embed_positions(embed_positions, position_ids):
    """
    Move `embed_positions` to `position_ids.device` and tile it `position_ids.shape[0]` times along a new
    leading axis (for a 2-D table the result is `(position_ids.shape[0], *embed_positions.shape)`).

    Unlike `GPTJAttention._get_embed_positions`, this does not cache the device move. It is not referenced
    elsewhere in this module; presumably kept as public API for backward compatibility.
    """
    return embed_positions.to(position_ids.device).repeat(position_ids.shape[0], 1, 1)


def rotate_every_two(x: torch.Tensor) -> torch.Tensor:
    """
    Rotate interleaved channel pairs of a rank-4 tensor by 90 degrees.

    Along the last axis, each pair `(x[2i], x[2i+1])` becomes `(-x[2i+1], x[2i])`.
    """
    x1 = x[:, :, :, ::2]
    x2 = x[:, :, :, 1::2]
    x = torch.stack((-x2, x1), dim=-1)
    return x.flatten(-2)  # in einsum notation: rearrange(x, '... d j -> ... (d j)')


def apply_rotary_pos_emb(tensor: torch.Tensor, sin: torch.Tensor, cos: torch.Tensor) -> torch.Tensor:
    """
    Apply interleaved rotary position embeddings to `tensor`.

    `sin`/`cos` get a broadcast axis inserted at dim 2 and each column duplicated (repeat-interleaved) along
    the last axis, so a pair of adjacent channels shares one frequency. Within this module, `tensor` is a
    query/key of shape `(batch, seq_len, num_heads, rot_dim)` and `sin`/`cos` are `(batch_or_1, seq_len,
    rot_dim // 2)`.
    """
    sin = torch.repeat_interleave(sin[:, :, None, :], 2, 3)
    cos = torch.repeat_interleave(cos[:, :, None, :], 2, 3)
    return (tensor * cos) + (rotate_every_two(tensor) * sin)


class GPTJAttention(nn.Module):
    """
    Eager multi-head self-attention for GPT-J with (optionally partial) rotary position embeddings.

    All four projections are bias-free. When `config.rotary_dim` is set, only the first `rotary_dim` channels of
    each query/key head are rotated; the remaining channels pass through unchanged. Otherwise the whole head is
    rotated.
    """

    def __init__(self, config, layer_idx=None):
        super().__init__()
        self.config = config
        self.max_positions = config.max_position_embeddings

        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)

        self.is_causal = True
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.embed_dim = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_attention_heads
        if self.head_dim * self.num_attention_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and"
                f" `num_attention_heads`: {self.num_attention_heads})."
            )
        self.scale_attn = math.sqrt(self.head_dim)

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
        self.rotary_dim = config.rotary_dim
        # Width of the sin/cos table. Without a partial `rotary_dim`, the full per-head width is rotated (see
        # `_project_qkv`), so the table must be `head_dim` wide; sizing it by `embed_dim` would fail to broadcast
        # against `(batch, seq, heads, head_dim)` queries/keys whenever there is more than one head.
        self.pos_embd_dim = self.rotary_dim or self.head_dim
        # Non-persistent: not stored in checkpoints; `GPTJPreTrainedModel._init_weights` refills it.
        self.register_buffer(
            "embed_positions", create_sinusoidal_positions(self.max_positions, self.pos_embd_dim), persistent=False
        )

    def _split_heads(self, tensor, num_attention_heads, attn_head_size, rotary):
        """
        Splits hidden dim into attn_head_size and num_attention_heads.

        With `rotary=True` the result keeps the `(..., seq, heads, head_dim)` layout needed by the rotary step;
        otherwise a rank-4 result is permuted to `(batch, heads, seq, head_dim)` and a rank-5 result to
        `(batch, blocks, heads, block_length, head_dim)`.
        """
        new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size)
        tensor = tensor.view(new_shape)
        if rotary:
            return tensor
        if len(tensor.shape) == 5:
            return tensor.permute(0, 1, 3, 2, 4)  # (batch, blocks, head, block_length, head_features)
        elif len(tensor.shape) == 4:
            return tensor.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)
        else:
            raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")

    def _merge_heads(self, tensor, num_attention_heads, attn_head_size):
        """
        Merges attn_head_size dim and num_attn_heads dim into hidden dim (inverse of `_split_heads` with
        `rotary=False`).
        """
        if len(tensor.shape) == 5:
            tensor = tensor.permute(0, 1, 3, 2, 4).contiguous()
        elif len(tensor.shape) == 4:
            tensor = tensor.permute(0, 2, 1, 3).contiguous()
        else:
            raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")
        new_shape = tensor.size()[:-2] + (num_attention_heads * attn_head_size,)
        return tensor.view(new_shape)

    def _attn(
        self,
        query,
        key,
        value,
        attention_mask=None,
    ):
        """
        Scaled dot-product attention: `softmax(Q K^T / sqrt(head_dim) + mask) V`.

        The additive `attention_mask` is expected to be broadcastable to the score matrix. Returns
        `(attn_output, attn_weights)`, where the weights have been cast back to `value.dtype` and passed through
        attention dropout.
        """
        # Keep the attention weights computation in fp32 to avoid overflow issues
        query = query.to(torch.float32)
        key = key.to(torch.float32)

        attn_weights = torch.matmul(query, key.transpose(-1, -2))
        attn_weights = attn_weights / self.scale_attn

        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
        attn_weights = attn_weights.to(value.dtype)
        attn_weights = self.attn_dropout(attn_weights)

        attn_output = torch.matmul(attn_weights, value)

        return attn_output, attn_weights

    def _get_embed_positions(self, position_ids):
        """
        Return the sin/cos table tiled once per row of `position_ids`, on `position_ids.device`.

        The device move is written back to the buffer so later calls skip the transfer.
        """
        embed_positions = self.embed_positions
        if embed_positions.device != position_ids.device:
            embed_positions = embed_positions.to(position_ids.device)
            self.embed_positions = embed_positions
        return embed_positions.repeat(position_ids.shape[0], 1, 1)

    def _project_qkv(self, hidden_states, position_ids):
        """
        Project `hidden_states` into query/key/value heads and apply rotary embeddings to query and key.

        Returns `(query, key, value)`, each laid out as `(batch, num_heads, seq_len, head_dim)`. Assumes
        `hidden_states` is `(batch, seq_len, embed_dim)` and `position_ids` is `(batch_or_1, seq_len)` -- the
        shapes `GPTJModel.forward` produces.
        """
        query = self.q_proj(hidden_states)
        key = self.k_proj(hidden_states)
        value = self.v_proj(hidden_states)

        # query/key stay (batch, seq, heads, head_dim) for the rotary step; value goes straight to
        # (batch, heads, seq, head_dim).
        query = self._split_heads(query, self.num_attention_heads, self.head_dim, True)
        key = self._split_heads(key, self.num_attention_heads, self.head_dim, True)
        value = self._split_heads(value, self.num_attention_heads, self.head_dim, False)

        # Look up the sin/cos row for every position id, then split it into its sin and cos halves.
        embed_positions = self._get_embed_positions(position_ids)
        repeated_position_ids = position_ids.unsqueeze(-1).repeat(1, 1, embed_positions.shape[-1])
        sincos = torch.gather(embed_positions, 1, repeated_position_ids).to(key.dtype)
        sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)

        if self.rotary_dim is not None:
            # Partial rotary: rotate only the first `rotary_dim` channels of each head.
            k_rot = key[:, :, :, : self.rotary_dim]
            k_pass = key[:, :, :, self.rotary_dim :]

            q_rot = query[:, :, :, : self.rotary_dim]
            q_pass = query[:, :, :, self.rotary_dim :]

            k_rot = apply_rotary_pos_emb(k_rot, sin, cos)
            q_rot = apply_rotary_pos_emb(q_rot, sin, cos)

            key = torch.cat([k_rot, k_pass], dim=-1)
            query = torch.cat([q_rot, q_pass], dim=-1)
        else:
            key = apply_rotary_pos_emb(key, sin, cos)
            query = apply_rotary_pos_emb(query, sin, cos)

        # (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim)
        key = key.permute(0, 2, 1, 3)
        query = query.permute(0, 2, 1, 3)
        return query, key, value

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        layer_past: Cache | None = None,
        attention_mask: torch.FloatTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        use_cache: bool | None = False,
        output_attentions: bool | None = False,
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """
        Run eager self-attention.

        When `layer_past` is given, the new key/value heads are appended to it via `Cache.update` and the
        returned (full) key/value are attended over. Returns `(attn_output, attn_weights)` where `attn_output`
        is `(batch, seq_len, embed_dim)` after the output projection and residual dropout.
        """
        query, key, value = self._project_qkv(hidden_states, position_ids)

        if layer_past is not None:
            key, value = layer_past.update(key, value, self.layer_idx)

        # compute self-attention: V x Softmax(QK^T)
        attn_output, attn_weights = self._attn(query, key, value, attention_mask)

        attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim)
        attn_output = self.out_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)

        return attn_output, attn_weights


class GPTJFlashAttention2(GPTJAttention):
    """
    GPTJ flash attention module. This module inherits from `GPTJAttention` as the weights of the module stays
    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
    flash attention and deal with padding tokens in case the input contains any of them.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment,
        # that was made default for flash_attn>=2.1. This attribute is used to handle this difference.
        # Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a
        # wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        layer_past: Cache | None = None,
        attention_mask: torch.FloatTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        use_cache: bool | None = False,
        output_attentions: bool | None = False,
        **kwargs,
    ) -> tuple[torch.Tensor, None]:
        """
        Run self-attention through `_flash_attention_forward`.

        Returns `(attn_output, None)`: flash attention never materializes the attention probabilities, so no
        weights are available even when `output_attentions=True`.
        """
        query, key, value = self._project_qkv(hidden_states, position_ids)

        if layer_past is not None:
            key, value = layer_past.update(key, value, self.layer_idx)

        # Flash attention expects (batch, seq_len, num_heads, head_dim), so undo the head-first layout used by
        # the cache.
        key = key.permute(0, 2, 1, 3).contiguous()
        query = query.permute(0, 2, 1, 3).contiguous()
        value = value.permute(0, 2, 1, 3).contiguous()

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
        # therefore the input hidden states gets silently casted in float32. Hence, we need
        # cast them back in the correct dtype just to be sure everything works as expected.
        # This might slowdown training & inference so it is recommended to not cast the LayerNorms
        # in fp32. (LlamaRMSNorm handles it correctly)
        input_dtype = query.dtype
        device_type = query.device.type if query.device.type != "mps" else "cpu"
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled(device_type):
                target_dtype = torch.get_autocast_dtype(device_type)
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_is_quantized"):
                target_dtype = self.config.dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query = query.to(target_dtype)
            key = key.to(target_dtype)
            value = value.to(target_dtype)

        attention_dropout = self.config.attn_pdrop if self.training else 0.0  # attn_pdrop in gptj

        query_length = query.shape[1]

        # Compute attention. The result is the attention *output* per head, not attention probabilities.
        attn_output = _flash_attention_forward(
            query,
            key,
            value,
            attention_mask,
            query_length,
            dropout=attention_dropout,
            is_causal=self.is_causal,
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
        )

        # Merge heads: (batch, seq, heads, head_dim) -> (batch, seq, heads * head_dim)
        attn_output = attn_output.reshape(
            attn_output.shape[0], attn_output.shape[1], attn_output.shape[2] * attn_output.shape[3]
        )
        attn_output = self.out_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)

        if output_attentions:
            logger.warning_once(
                f"{self.__class__.__name__} does not compute attention weights; `output_attentions=True` will"
                " yield `None` for this layer."
            )
        # Previously the per-head attention output was returned here and surfaced as "attentions"; report
        # None instead, since flash attention never materializes the weights.
        return attn_output, None


GPTJ_ATTENTION_CLASSES = {
    "eager": GPTJAttention,
    "flash_attention_2": GPTJFlashAttention2,
}


class GPTJMLP(nn.Module):
    """Two-layer feed-forward network: `fc_in` -> activation -> `fc_out` -> residual dropout."""

    def __init__(self, intermediate_size, config):
        # intermediate_size is `config.n_inner`, or 4 * n_embd when that is None (see GPTJBlock).
        super().__init__()
        embed_dim = config.n_embd

        self.fc_in = nn.Linear(embed_dim, intermediate_size)
        self.fc_out = nn.Linear(intermediate_size, embed_dim)

        self.act = ACT2FN[config.activation_function]
        self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, hidden_states: torch.FloatTensor | None) -> torch.FloatTensor:
        hidden_states = self.fc_in(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states = self.fc_out(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return hidden_states


class GPTJBlock(GradientCheckpointingLayer):
    """
    GPT-J transformer block with parallel attention and MLP.

    A single LayerNorm feeds both the attention and the MLP; their outputs are summed together with the
    un-normalized residual.
    """

    def __init__(self, config, layer_idx=None):
        super().__init__()
        inner_dim = config.n_inner if config.n_inner is not None else 4 * config.n_embd
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.attn = GPTJ_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
        self.mlp = GPTJMLP(inner_dim, config)

    def forward(
        self,
        hidden_states: torch.FloatTensor | None,
        layer_past: Cache | None = None,
        attention_mask: torch.FloatTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        use_cache: bool | None = False,
        output_attentions: bool | None = False,
        **kwargs,
    ) -> tuple[torch.Tensor] | tuple[torch.Tensor, tuple[torch.FloatTensor, ...]] | None:
        """Returns `(hidden_states, attn_weights)`; `attn_weights` is whatever the attention module returns."""
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_outputs, attn_weights = self.attn(
            hidden_states=hidden_states,
            layer_past=layer_past,
            attention_mask=attention_mask,
            position_ids=position_ids,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        # Parallel formulation: the MLP reads the same normalized input as attention, not the attention output.
        feed_forward_hidden_states = self.mlp(hidden_states)
        hidden_states = attn_outputs + feed_forward_hidden_states + residual

        return hidden_states, attn_weights


@auto_docstring
class GPTJPreTrainedModel(PreTrainedModel):
    config: GPTJConfig
    base_model_prefix = "transformer"
    supports_gradient_checkpointing = True
    _no_split_modules = ["GPTJBlock"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn = True
    _can_compile_fullgraph = True

    def _init_weights(self, module):
        super()._init_weights(module)
        # The rotary table is a non-persistent buffer, so it is (re)computed here rather than loaded.
        if isinstance(module, GPTJAttention):
            init.copy_(module.embed_positions, create_sinusoidal_positions(module.max_positions, module.pos_embd_dim))


@auto_docstring
class GPTJModel(GPTJPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.embed_dim = config.n_embd
        self.vocab_size = config.vocab_size
        self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
        self.drop = nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList([GPTJBlock(config, layer_idx=i) for i in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.wte

    def set_input_embeddings(self, new_embeddings):
        self.wte = new_embeddings

    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        attention_mask: torch.FloatTensor | None = None,
        token_type_ids: torch.LongTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | BaseModelOutputWithPast:
        r"""
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_dim)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
            model's internal embedding lookup matrix.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        seq_length = inputs_embeds.shape[1]
        if position_ids is None:
            # Continue numbering after the tokens already held in the cache; shape (1, seq_length).
            past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
            position_ids = torch.arange(seq_length, device=inputs_embeds.device) + past_key_values_length
            position_ids = position_ids.unsqueeze(0)

        causal_mask = create_causal_mask(
            config=self.config,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        hidden_states = inputs_embeds

        if token_type_ids is not None:
            # Token types share the word-embedding matrix.
            token_type_ids = token_type_ids.view(-1, seq_length)
            token_type_embeds = self.wte(token_type_ids)
            hidden_states = hidden_states + token_type_embeds

        hidden_states = self.drop(hidden_states)

        output_shape = (-1, seq_length, hidden_states.size(-1))

        all_self_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None
        for block in self.h:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            outputs = block(
                hidden_states,
                layer_past=past_key_values,
                attention_mask=causal_mask,
                position_ids=position_ids,
                use_cache=use_cache,
                output_attentions=output_attentions,
            )

            hidden_states = outputs[0]
            if output_attentions:
                all_self_attentions = all_self_attentions + (outputs[1],)

        hidden_states = self.ln_f(hidden_states)

        hidden_states = hidden_states.view(output_shape)
        # Add last hidden state
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            # Entries that are None (e.g. no cache when use_cache is off) are dropped from the tuple.
            return tuple(
                v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions] if v is not None
            )

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )


@auto_docstring(
    custom_intro="""
    The GPT-J Model transformer with a language modeling head on top.
    """
)
class GPTJForCausalLM(GPTJPreTrainedModel, GenerationMixin):
    _tied_weights_keys = {"lm_head.weight": "transformer.wte.weight"}

    def __init__(self, config):
        super().__init__(config)
        self.transformer = GPTJModel(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        attention_mask: torch.FloatTensor | None = None,
        token_type_ids: torch.LongTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs,
    ) -> tuple | CausalLMOutputWithPast:
        r"""
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_dim)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
            model's internal embedding lookup matrix.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
        """
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]

        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        # (an int keeps the last `logits_to_keep` positions; 0 keeps all of them).
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        if not return_dict:
            output = (logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )


@auto_docstring(
    custom_intro="""
    The GPT-J Model transformer with a sequence classification head on top (linear layer).

    [`GPTJForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT, GPT-2, GPT-Neo) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """
)
class GPTJForSequenceClassification(GPTJPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.transformer = GPTJModel(config)
        self.score = nn.Linear(config.n_embd, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        attention_mask: torch.FloatTensor | None = None,
        token_type_ids: torch.LongTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | SequenceClassifierOutputWithPast:
        r"""
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_dim)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
            model's internal embedding lookup matrix.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        # Score every position; the pooled (last non-pad) position is selected below.
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            last_non_pad_token = -1
        elif input_ids is not None:
            # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
            last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
        else:
            last_non_pad_token = -1
            logger.warning_once(
                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
            )

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]

        loss = None
        if labels is not None:
            labels = labels.to(pooled_logits.device)
            # Infer (and persist on the config) the problem type on first use.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)
        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )


@auto_docstring
class GPTJForQuestionAnswering(GPTJPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.transformer = GPTJModel(config)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.FloatTensor | None = None,
        token_type_ids: torch.LongTensor | None = None,
        position_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        start_positions: torch.LongTensor | None = None,
        end_positions: torch.LongTensor | None = None,
        output_attentions: bool | None = None,
        output_hidden_states: bool | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple | QuestionAnsweringModelOutput:
        r"""
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_dim)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
            model's internal embedding lookup matrix.
        """
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # QA outputs never include a cache, so don't build one. This also makes the tuple layout of `outputs`
        # independent of `config.use_cache`: index 0 is the last hidden state, the rest are optional extras.
        outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            use_cache=False,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        # One span-start and one span-end logit per token (assumes config.num_labels == 2 -- TODO confirm).
        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # On multi-GPU the positions may carry an extra trailing dimension; squeeze it, then move the positions
            # to the logits' device.
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            start_positions = start_positions.to(start_logits.device)
            end_positions = end_positions.to(end_logits.device)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            # With use_cache=False there is no cache entry, so the optional extras start at index 1.
            output = (start_logits, end_logits) + outputs[1:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


__all__ = [
    "GPTJForCausalLM",
    "GPTJForQuestionAnswering",
    "GPTJForSequenceClassification",
    "GPTJModel",
    "GPTJPreTrainedModel",
]