# Copyright 2022 The HuggingFace Team and Microsoft Research AI4Science All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BioGPT model."""

import math

import torch
import torch.nn as nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache
from ...generation import GenerationMixin
from ...masking_utils import create_causal_mask
from ...modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    SequenceClassifierOutputWithPast,
    TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
    TransformersKwargs,
    auto_docstring,
    can_return_tuple,
    logger,
)
from ...utils.generic import merge_with_config_defaults
from ...utils.output_capturing import capture_outputs
from ..bart.modeling_bart import (
    BartAttention,
    BartDecoderLayer,
    BartScaledWordEmbedding,
)
from ..opt.modeling_opt import OPTLearnedPositionalEmbedding
from .configuration_biogpt import BioGptConfig


class BioGptLearnedPositionalEmbedding(OPTLearnedPositionalEmbedding):
    def forward(
        self,
        attention_mask: torch.LongTensor,
        past_key_values_length: int = 0,
        position_ids: torch.LongTensor | None = None,
    ):
        """
        Look up learned position embeddings.

        All work is delegated to `OPTLearnedPositionalEmbedding.forward`; this override only fixes the
        argument order. `attention_mask` is presumably `(batch_size, seq_len)` — the OPT parent defines
        the exact contract, including how positions are derived when `position_ids` is None.
        """
        return super().forward(attention_mask, past_key_values_length, position_ids)


class BioGptScaledWordEmbedding(BartScaledWordEmbedding):
    pass


class BioGptAttention(BartAttention):
    pass


class BioGptDecoderLayer(BartDecoderLayer):
    """Decoder-only variant of `BartDecoderLayer`: causal self-attention + feed-forward, no cross-attention."""

    def __init__(self, config: BioGptConfig, layer_idx: int | None = None):
        super().__init__(config)
        self.embed_dim = config.hidden_size

        self.self_attn = BioGptAttention(
            embed_dim=self.embed_dim,
            num_heads=config.num_attention_heads,
            dropout=config.attention_probs_dropout_prob,
            is_decoder=True,
            is_causal=True,
            config=config,
            layer_idx=layer_idx,
        )
        self.dropout = config.hidden_dropout_prob
        self.activation_fn = ACT2FN[config.hidden_act]

        self.fc1 = nn.Linear(self.embed_dim, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, self.embed_dim)

        # BioGPT is decoder-only, so drop the cross-attention modules the BART parent created.
        del self.encoder_attn
        del self.encoder_attn_layer_norm

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        past_key_values: Cache | None = None,
        use_cache: bool | None = True,
        position_ids: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """
        Run one pre-LayerNorm decoder block.

        The block is causal self-attention followed by a two-layer feed-forward network. Each
        sub-block is wrapped in a residual connection.

        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.Tensor`, *optional*): mask produced by `create_causal_mask` in
                `BioGptModel.forward`, forwarded unchanged to self-attention (its exact form depends on the
                attention implementation)
            past_key_values (`Cache`, *optional*): cached past key and value projection states
            use_cache (`bool`, *optional*): accepted for interface compatibility; not read by this method
            position_ids (`torch.LongTensor`, *optional*): forwarded to self-attention

        Returns:
            `torch.Tensor` with the same shape as `hidden_states`.
        """
        residual = hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)

        # Self Attention (attention weights are discarded here)
        hidden_states, _ = self.self_attn(
            hidden_states=hidden_states,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            **kwargs,
        )
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
        hidden_states = self.fc2(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states

        return hidden_states


@auto_docstring
class BioGptPreTrainedModel(PreTrainedModel):
    config: BioGptConfig
    base_model_prefix = "biogpt"
    supports_gradient_checkpointing = True
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _can_compile_fullgraph = True
    # Modules whose outputs are recorded as `hidden_states` / `attentions` when requested.
    _can_record_outputs = {
        "hidden_states": BioGptDecoderLayer,
        "attentions": BioGptAttention,
    }


@auto_docstring
class BioGptModel(BioGptPreTrainedModel):
    def __init__(self, config: BioGptConfig):
        super().__init__(config)
        self.config = config
        self.layerdrop = config.layerdrop
        self.dropout = config.hidden_dropout_prob
        self.embed_dim = config.hidden_size
        self.padding_idx = config.pad_token_id
        embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0

        self.embed_tokens = BioGptScaledWordEmbedding(
            config.vocab_size, self.embed_dim, self.padding_idx, embed_scale=embed_scale
        )
        self.embed_positions = BioGptLearnedPositionalEmbedding(config.max_position_embeddings, self.embed_dim)

        self.layers = nn.ModuleList([BioGptDecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
        self.layer_norm = nn.LayerNorm(self.embed_dim)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    @merge_with_config_defaults
    @capture_outputs
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.FloatTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        past_key_values: Cache | None = None,
        use_cache: bool | None = None,
        position_ids: torch.LongTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | BaseModelOutputWithPastAndCrossAttentions:
        # Exactly one of `input_ids` / `inputs_embeds` must be given: the XOR is True when both
        # are None and when both are set.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # initialize past_key_values
        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        batch_size, seq_length = inputs_embeds.size()[:-1]
        past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
        if attention_mask is None:
            # required mask seq length can be calculated via length of past cache
            mask_seq_length = past_key_values_length + seq_length
            attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)

        causal_mask = create_causal_mask(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
        )

        # embed positions: by default, new tokens continue right after the cached prefix.
        if position_ids is None:
            position_ids = torch.arange(seq_length, device=inputs_embeds.device) + past_key_values_length
            position_ids = position_ids.unsqueeze(0)

        positions = self.embed_positions(attention_mask, past_key_values_length, position_ids=position_ids)
        hidden_states = inputs_embeds + positions
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)

        for decoder_layer in self.layers:
            # LayerDrop: during training, skip each layer with probability `self.layerdrop`.
            if self.training:
                dropout_probability = torch.rand([])
                if dropout_probability < self.layerdrop:
                    continue

            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                past_key_values=past_key_values,
                use_cache=use_cache,
                position_ids=position_ids,
                **kwargs,
            )

        hidden_states = self.layer_norm(hidden_states)

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
        )


@auto_docstring(
    custom_intro="""
    BioGPT Model with a `language modeling` head on top for CLM fine-tuning.
    """
)
class BioGptForCausalLM(BioGptPreTrainedModel, GenerationMixin):
    # The LM head shares its weight with the input token embedding.
    _tied_weights_keys = {"output_projection.weight": "biogpt.embed_tokens.weight"}

    def __init__(self, config):
        super().__init__(config)

        self.biogpt = BioGptModel(config)
        self.output_projection = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.output_projection

    def set_output_embeddings(self, new_embeddings):
        self.output_projection = new_embeddings

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.FloatTensor | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        past_key_values: Cache | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        position_ids: torch.LongTensor | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | CausalLMOutputWithCrossAttentions:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
        """
        outputs = self.biogpt(
            input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=use_cache,
            position_ids=position_ids,
            **kwargs,
        )

        hidden_states = outputs[0]
        # An int keeps the last `logits_to_keep` positions (0 -> slice(0, None) keeps all of them);
        # a tensor is used directly as the index.
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.output_projection(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )


@auto_docstring
class BioGptForTokenClassification(BioGptPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.biogpt = BioGptModel(config)
        # Fall back to the hidden-state dropout when no dedicated classifier dropout is configured.
        if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
            classifier_dropout = config.classifier_dropout
        else:
            classifier_dropout = config.hidden_dropout_prob
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        token_type_ids: torch.LongTensor | None = None,
        attention_mask: torch.FloatTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        position_ids: torch.LongTensor | None = None,
        **kwargs,
    ) -> tuple | TokenClassifierOutput:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss (Cross-Entropy). Indices should be in
            `[0, ..., config.num_labels - 1]`. When `attention_mask` is provided, positions where it is not `1`
            are ignored in the loss.
        """
        transformer_outputs = self.biogpt(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            position_ids=position_ids,
            **kwargs,
        )

        hidden_states = transformer_outputs[0]
        hidden_states = self.dropout(hidden_states)
        logits = self.classifier(hidden_states)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            # Only keep active parts of the loss: masked-out positions get `ignore_index`.
            if attention_mask is not None:
                active_loss = attention_mask.view(-1) == 1
                active_logits = logits.view(-1, self.num_labels)
                active_labels = torch.where(
                    active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
                )
                loss = loss_fct(active_logits, active_labels)
            else:
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )


@auto_docstring(
    custom_intro="""
    The BioGpt Model transformer with a sequence classification head on top (linear layer).

    [`BioGptForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-2) do.

    Since it does classification on the last token, it is required to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """
)
class BioGptForSequenceClassification(BioGptPreTrainedModel):
    def __init__(self, config: BioGptConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.biogpt = BioGptModel(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.FloatTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        labels: torch.LongTensor | None = None,
        use_cache: bool | None = None,
        position_ids: torch.LongTensor | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs,
    ) -> tuple | SequenceClassifierOutputWithPast:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        transformer_outputs = self.biogpt(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            position_ids=position_ids,
            **kwargs,
        )
        hidden_states = transformer_outputs[0]
        # NOTE: the pooling indices below are sequence positions; they line up with `logits` only when
        # every position is kept (the default `logits_to_keep=0`).
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.score(hidden_states[:, slice_indices, :])

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None:
            last_non_pad_token = -1
        elif input_ids is not None:
            # Position of the right-most non-padding token in each row. Unlike "number of non-pad
            # tokens - 1", this is also correct for left-padded batches.
            non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
            token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
            last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
        else:
            last_non_pad_token = -1
            logger.warning_once(
                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
            )

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]

        loss = None
        if labels is not None:
            # Infer (and persist on the config) the problem type from `num_labels` and the label dtype.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    def get_input_embeddings(self):
        return self.biogpt.embed_tokens

    def set_input_embeddings(self, value):
        self.biogpt.embed_tokens = value


__all__ = [
    "BioGptForCausalLM",
    "BioGptForTokenClassification",
    "BioGptForSequenceClassification",
    "BioGptModel",
    "BioGptPreTrainedModel",
]