| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
2771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485 |
- # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
- # This file was automatically generated from src/transformers/models/evolla/modular_evolla.py.
- # Do NOT edit this file manually as any edits will be overwritten by the generation of
- # the file from the modular. If any change should be done, please apply the change to the
- # modular_evolla.py file directly. One of our CI enforces this.
- # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
- # Copyright 2025 Westlake Representational Learning Lab (Fajie Yuan Lab) team and the HuggingFace Inc. team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import math
- from collections.abc import Callable
- from dataclasses import dataclass
- from typing import Optional
- import torch
- from torch import nn
- from ... import initialization as init
- from ...activations import ACT2FN
- from ...cache_utils import Cache, DynamicCache
- from ...generation import GenerationMixin
- from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
- from ...masking_utils import create_bidirectional_mask, create_causal_mask
- from ...modeling_layers import GradientCheckpointingLayer
- from ...modeling_outputs import (
- BaseModelOutputWithCrossAttentions,
- BaseModelOutputWithPast,
- BaseModelOutputWithPoolingAndCrossAttentions,
- CausalLMOutputWithPast,
- ModelOutput,
- )
- from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
- from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
- from ...processing_utils import Unpack
- from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
- from ...utils.generic import maybe_autocast, merge_with_config_defaults
- from ...utils.output_capturing import OutputRecorder, capture_outputs
- from .configuration_evolla import EvollaConfig, SaProtConfig
def create_position_ids_from_input_ids(input_ids, padding_idx):
    """
    Map every non-padding token to its position number, offset by ``padding_idx``; padding tokens stay at
    ``padding_idx``. Positions therefore start at ``padding_idx + 1``. Adapted from fairseq's `utils.make_positions`.

    Args:
        input_ids (`torch.Tensor`): integer token ids; positions are counted along dimension 1.
        padding_idx (`int`): id of the padding token, also used as the offset added to every position.

    Returns:
        `torch.LongTensor` with the same shape as `input_ids`.
    """
    # The int cast / type_as / long conversions are deliberately kept so the op sequence stays ONNX- and XLA-friendly.
    is_token = input_ids.ne(padding_idx).int()
    running_count = torch.cumsum(is_token, dim=1).type_as(is_token)
    positions = running_count * is_token
    return positions.long() + padding_idx
class EvollaSaProtEmbeddings(nn.Module):
    """
    Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.

    Combines word embeddings with optional ESM-style token-dropout rescaling, optional absolute position embeddings
    and an optional LayerNorm, then multiplies the result by `attention_mask` when one is given.
    """

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)

        # Optional LayerNorm applied after the position embeddings; `forward` skips it when it is None.
        if config.emb_layer_norm_before:
            self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        else:
            self.layer_norm = None
        # NOTE(review): this dropout module is never applied in `forward` (the call there is commented out).
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )

        self.padding_idx = config.pad_token_id
        # Learned absolute position embeddings only exist for the "absolute" type; the "rotary" type is applied
        # inside the self-attention layers instead.
        if self.position_embedding_type == "absolute":
            self.position_embeddings = nn.Embedding(
                config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
            )
        self.token_dropout = config.token_dropout
        self.mask_token_id = config.mask_token_id
        # remove the position_ids in EsmEmbeddings
        # (this sets the buffer registered above to None; `forward` only uses its local `position_ids`)
        self.position_ids = None

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        inputs_embeds=None,
    ):
        """
        Build input embeddings.

        Args:
            input_ids: token ids. Required when `inputs_embeds` is None. When given, they also drive padding-aware
                position ids and token-dropout rescaling.
            attention_mask: optional mask. Its sums over the last dim give the per-sequence lengths used for
                token-dropout rescaling, and the final embeddings are multiplied by it.
            position_ids: optional explicit position ids (only consumed by absolute position embeddings).
            inputs_embeds: optional precomputed embeddings used instead of looking up `input_ids`.

        Returns:
            The embeddings tensor.
        """
        if position_ids is None:
            if input_ids is not None:
                # Create the position ids from the input token ids. Any padded tokens remain padded.
                position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx)
            else:
                position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        # Note that if we want to support EVOLLA_SA_PROT-1 (not 1b!) in future then we need to support an
        # embedding_scale factor here.
        embeddings = inputs_embeds

        # Matt: EVOLLA_SA_PROT has the option to handle masking in MLM in a slightly unusual way. If the token_dropout
        # flag is False then it is handled in the same was as BERT/RoBERTa. If it is set to True, however,
        # masked tokens are treated as if they were selected for input dropout and zeroed out.
        # This "mask-dropout" is compensated for when masked tokens are not present, by scaling embeddings by
        # a factor of (fraction of unmasked tokens during training) / (fraction of unmasked tokens in sample).
        # This is analogous to the way that dropout layers scale down outputs during evaluation when not
        # actually dropping out values (or, equivalently, scale up their un-dropped outputs in training).
        if self.token_dropout and input_ids is not None:
            embeddings = embeddings.masked_fill((input_ids == self.mask_token_id).unsqueeze(-1), 0.0)
            mask_ratio_train = 0.15 * 0.8  # Hardcoded as the ratio used in all EVOLLA_SA_PROT model training runs
            src_lengths = attention_mask.sum(-1) if attention_mask is not None else input_ids.shape[1]
            mask_ratio_observed = (input_ids == self.mask_token_id).sum(-1).float() / src_lengths
            # NOTE(review): if every counted position is a mask token, `1 - mask_ratio_observed` is 0 and this
            # divides by zero.
            embeddings = (embeddings * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]).to(
                embeddings.dtype
            )

        if self.position_embedding_type == "absolute":
            position_embeddings = self.position_embeddings(position_ids)
            embeddings = embeddings + position_embeddings

        if self.layer_norm is not None:
            embeddings = self.layer_norm(embeddings)
        if attention_mask is not None:
            # Multiply by the mask, which zeroes positions whose mask value is 0.
            embeddings = (embeddings * attention_mask.unsqueeze(-1)).to(embeddings.dtype)
        # Matt: I think this line was copied incorrectly from BERT, disabling it for now.
        # embeddings = self.dropout(embeddings)
        return embeddings

    def create_position_ids_from_inputs_embeds(self, inputs_embeds):
        """
        We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.

        Args:
            inputs_embeds: torch.Tensor

        Returns: torch.Tensor of positions `padding_idx + 1 ... padding_idx + seq_len`, expanded to
            `inputs_embeds.size()[:-1]`.
        """
        input_shape = inputs_embeds.size()[:-1]
        sequence_length = input_shape[1]

        position_ids = torch.arange(
            self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
        )
        return position_ids.unsqueeze(0).expand(input_shape)
def rotate_half_esm(x):
    """Split the last dimension of ``x`` into halves ``(a, b)`` (via `chunk`) and return ``cat(-b, a)``."""
    front, back = torch.chunk(x, 2, dim=-1)
    return torch.cat([back.neg(), front], dim=-1)
def apply_rotary_pos_emb_esm(x, cos, sin):
    """
    Apply rotary position embeddings to ``x``.

    ``cos`` and ``sin`` are indexed as 4-D tensors and truncated along dim 2 to ``x.shape[-2]``. They are then
    combined as ``x * cos + rotate_half(x) * sin``.
    """
    seq_len = x.shape[-2]
    cos_trunc = cos[:, :, :seq_len, :]
    sin_trunc = sin[:, :, :seq_len, :]
    # Inline form of `rotate_half_esm`: split the last dim in two and map (a, b) -> (-b, a).
    first_half, second_half = x.chunk(2, dim=-1)
    rotated = torch.cat((-second_half, first_half), dim=-1)
    return (x * cos_trunc) + (rotated * sin_trunc)
class EvollaSaProtRotaryEmbedding(nn.Module):
    """
    Rotary position embeddings based on those in
    [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). Query and keys are transformed by rotation
    matrices which depend on their relative positions.

    The cos/sin tables are cached on the instance and rebuilt whenever the sequence length or the device changes.
    """

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim
        # Generate and save the inverse frequency buffer (non trainable)
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Cache state for `_update_cos_sin_tables`; all None until the first call.
        self._seq_len_cached = None
        self._cos_cached = None
        self._sin_cached = None

    def _update_cos_sin_tables(self, x, seq_dimension=2):
        """
        Return `(cos, sin)` tables of shape `(1, 1, seq_len, 2 * len(inv_freq))`, where `seq_len` is the size of `x`
        along `seq_dimension`. The tables are recomputed only when that length or `x.device` differs from the cache.
        """
        seq_len = x.shape[seq_dimension]

        # Reset the tables if the sequence length has changed,
        # or if we're on a new device (possibly due to tracing for instance)
        # (On the first call `_seq_len_cached` is None, so the device check on the None cache is short-circuited.)
        # NOTE(review): the cache key ignores x.dtype; tables stay in inv_freq's dtype and `forward` casts afterwards.
        if seq_len != self._seq_len_cached or self._cos_cached.device != x.device:
            self._seq_len_cached = seq_len
            t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(self.inv_freq)
            freqs = torch.outer(t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)

            self._cos_cached = emb.cos()[None, None, :, :]
            self._sin_cached = emb.sin()[None, None, :, :]

        return self._cos_cached, self._sin_cached

    def forward(self, q: torch.Tensor, k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # Tables are sized from k's sequence length (dim -2); `apply_rotary_pos_emb_esm` truncates them to each
        # tensor's own length. Outputs are cast back to the input dtypes.
        self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2)

        return (
            apply_rotary_pos_emb_esm(q, self._cos_cached, self._sin_cached).to(dtype=q.dtype),
            apply_rotary_pos_emb_esm(k, self._cos_cached, self._sin_cached).to(dtype=k.dtype),
        )
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: torch.Tensor | None,
    scaling: float | None = None,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """
    Plain PyTorch scaled dot-product attention.

    Args:
        module: the calling module; only its `training` flag is read, to decide whether dropout is active.
        query, key, value: tensors whose dims 2 and 3 are treated as (sequence, head_dim). They are presumably
            `(batch, heads, seq, head_dim)`; TODO confirm against callers.
        attention_mask: optional additive mask added to the raw scores before the softmax.
        scaling: multiplier applied to the raw scores; defaults to `query.size(-1) ** -0.5`.
        dropout: dropout probability applied to the attention probabilities.
        **kwargs: accepted for interface compatibility and ignored.

    Returns:
        `(attn_output, attn_weights)`. `attn_output` is `attn_weights @ value` with dims 1 and 2 swapped, made
        contiguous.
    """
    scale = query.size(-1) ** -0.5 if scaling is None else scaling
    scores = query @ key.transpose(2, 3) * scale
    if attention_mask is not None:
        scores = scores + attention_mask

    probs = nn.functional.softmax(scores, dim=-1)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    context = (probs @ value).transpose(1, 2).contiguous()
    return context, probs
class EvollaSaProtSelfAttention(nn.Module):
    """
    Multi-head self- or cross-attention used by the SaProt protein encoder.

    Queries are pre-scaled by `attention_head_size ** -0.5` (ESM convention), so the attention backend is called with
    `scaling=1.0`. When `position_embedding_type == "rotary"`, rotary embeddings are applied to queries and keys.
    """

    def __init__(self, config, position_embedding_type=None, layer_idx=None, is_cross_attention=False):
        super().__init__()
        self.config = config

        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = config.attention_probs_dropout_prob

        self.rotary_embeddings = None
        self.position_embedding_type = position_embedding_type or getattr(
            config, "position_embedding_type", "absolute"
        )
        if self.position_embedding_type == "rotary":
            self.rotary_embeddings = EvollaSaProtRotaryEmbedding(dim=self.attention_head_size)

        self.is_decoder = config.is_decoder
        self.layer_idx = layer_idx
        # Queries are scaled explicitly in `forward`, so the backend must not scale again.
        self.scaling = 1.0
        self.is_causal = self.is_decoder and not is_cross_attention

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.FloatTensor | None = None,
        encoder_hidden_states: torch.FloatTensor | None = None,
        encoder_attention_mask: torch.FloatTensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor]:
        """
        Args:
            hidden_states: query-side input whose last dim is `hidden_size`.
            attention_mask: mask passed to the attention backend for self-attention.
            encoder_hidden_states: when given, keys and values are computed from it (cross-attention).
            encoder_attention_mask: mask used instead of `attention_mask` in the cross-attention case.
            **kwargs: forwarded to the attention backend.

        Returns:
            `(attn_output, attn_weights)`. `attn_output` has the leading shape of `hidden_states`, with the heads
            merged back into the last dim.
        """
        input_shape = hidden_states.shape[:-1]
        query_shape = (*input_shape, -1, self.attention_head_size)
        query_layer = self.query(hidden_states).view(query_shape).transpose(1, 2)

        is_cross_attention = encoder_hidden_states is not None
        current_states = encoder_hidden_states if is_cross_attention else hidden_states
        attention_mask = encoder_attention_mask if is_cross_attention else attention_mask

        # Keys and values must be split into heads using their *own* leading shape. Reusing the query shape
        # mis-splits heads (or fails) in cross-attention when the encoder length differs from the query length.
        kv_shape = (*current_states.shape[:-1], -1, self.attention_head_size)
        key_layer = self.key(current_states).view(kv_shape).transpose(1, 2)
        value_layer = self.value(current_states).view(kv_shape).transpose(1, 2)

        # Matt: Our BERT model (which this code was derived from) scales attention logits down by sqrt(head_dim).
        # EVOLLA_SA_PROT scales the query down by the same factor instead. Modulo numerical stability these are equivalent,
        # but not when rotary embeddings get involved. Therefore, we scale the query here to match the original
        # EVOLLA_SA_PROT code and fix rotary embeddings.
        query_layer = query_layer * self.attention_head_size**-0.5

        if self.position_embedding_type == "rotary":
            query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        attn_output, attn_weights = attention_interface(
            self,
            query_layer,
            key_layer,
            value_layer,
            attention_mask,
            dropout=0.0 if not self.training else self.dropout,
            scaling=self.scaling,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        return attn_output, attn_weights
class EvollaSaProtSelfOutput(nn.Module):
    """Project the attention output with a `hidden_size -> hidden_size` layer, apply dropout, add `input_tensor`."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        projected = self.dropout(self.dense(hidden_states))
        # Residual connection.
        return projected + input_tensor
class EvollaSaProtAttention(nn.Module):
    """Pre-LayerNorm attention block: LayerNorm -> (self/cross) attention -> output projection plus residual."""

    def __init__(self, config, layer_idx=None, is_cross_attention=False):
        super().__init__()
        self.self = EvollaSaProtSelfAttention(config, layer_idx=layer_idx, is_cross_attention=is_cross_attention)
        self.output = EvollaSaProtSelfOutput(config)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        **kwargs: Unpack[TransformersKwargs],
    ):
        normed = self.LayerNorm(hidden_states)
        context, _attn_weights = self.self(
            normed,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            **kwargs,
        )
        # The residual added inside `self.output` is the un-normalized input.
        return self.output(context, hidden_states)
def gelu(x):
    """
    Exact, erf-based GELU from the original EVOLLA_SA_PROT repo: ``x * 0.5 * (1 + erf(x / sqrt(2)))``.

    Used instead of `F.gelu` because that yields subtly wrong results relative to the original implementation.
    """
    half_x = x * 0.5
    return half_x * (1.0 + torch.erf(x / math.sqrt(2.0)))
class EvollaSaProtIntermediate(nn.Module):
    """Expand `hidden_size -> intermediate_size` with a linear layer, then apply the module-level erf-based `gelu`."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return gelu(self.dense(hidden_states))
class EvollaSaProtOutput(nn.Module):
    """Contract `intermediate_size -> hidden_size`, apply dropout, then add the residual `input_tensor`."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        contracted = self.dropout(self.dense(hidden_states))
        return contracted + input_tensor
class EvollaSaProtLayer(GradientCheckpointingLayer):
    """
    One SaProt transformer layer: pre-LN self-attention, optional cross-attention (decoder configurations only),
    then a pre-LN feed-forward block with a residual connection.
    """

    def __init__(self, config):
        super().__init__()
        # NOTE(review): `chunk_size_feed_forward` and `seq_len_dim` are stored, but `forward` calls
        # `feed_forward_chunk` directly, so no feed-forward chunking happens in this layer.
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = EvollaSaProtAttention(config)
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention
        # Cross-attention is only allowed for decoder configurations.
        if self.add_cross_attention:
            if not self.is_decoder:
                raise RuntimeError(f"{self} should be used as a decoder model if cross attention is added")
            self.crossattention = EvollaSaProtAttention(config, is_cross_attention=True)
        self.intermediate = EvollaSaProtIntermediate(config)
        self.output = EvollaSaProtOutput(config)
        # LayerNorm applied before the feed-forward block (see `feed_forward_chunk`).
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        **kwargs: Unpack[TransformersKwargs],
    ):
        """
        Returns:
            The layer output tensor (a single tensor, not a tuple).
        """
        attention_output = self.attention(
            hidden_states,
            attention_mask=attention_mask,
            **kwargs,
        )

        # Cross-attention only runs for decoder layers that received `encoder_hidden_states`.
        if self.is_decoder and encoder_hidden_states is not None:
            if not hasattr(self, "crossattention"):
                raise AttributeError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated"
                    " with cross-attention layers by setting `config.add_cross_attention=True`"
                )

            attention_output = self.crossattention(
                attention_output,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                **kwargs,
            )

        layer_output = self.feed_forward_chunk(attention_output)
        return layer_output

    def feed_forward_chunk(self, attention_output):
        # Pre-LN MLP: normalize, expand, contract, and add the un-normalized `attention_output` as the residual.
        attention_output_ln = self.LayerNorm(attention_output)
        intermediate_output = self.intermediate(attention_output_ln)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output
class EvollaSaProtEncoder(nn.Module):
    """Stack of `config.num_hidden_layers` `EvollaSaProtLayer`s followed by a final LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([EvollaSaProtLayer(config) for _ in range(config.num_hidden_layers)])
        self.emb_layer_norm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.gradient_checkpointing = False

    @can_return_tuple
    def forward(
        self,
        hidden_states,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        **kwargs: Unpack[TransformersKwargs],
    ):
        """
        Run every layer in order, then apply the final LayerNorm.

        Args:
            hidden_states: input embeddings.
            attention_mask: mask passed unchanged to every layer.
            encoder_hidden_states: passed to every layer (only consumed by cross-attention layers).
            encoder_attention_mask: passed to every layer (only consumed by cross-attention layers).
            **kwargs: forwarded to every layer.

        Returns:
            `BaseModelOutputWithCrossAttentions` with only `last_hidden_state` set.
        """
        for layer_module in self.layer:
            hidden_states = layer_module(
                hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                **kwargs,
            )

        # Use an explicit presence check: module truthiness is not a reliable test, because container modules that
        # define __len__ are falsy when empty.
        if self.emb_layer_norm_after is not None:
            hidden_states = self.emb_layer_norm_after(hidden_states)

        return BaseModelOutputWithCrossAttentions(last_hidden_state=hidden_states)
class EvollaSaProtPooler(nn.Module):
    """Pool a sequence by passing the hidden state at index 0 of dim 1 through a dense layer and tanh."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Only the first token's hidden state is used as the pooled summary.
        return self.activation(self.dense(hidden_states[:, 0]))
@auto_docstring
class EvollaSaProtPreTrainedModel(PreTrainedModel):
    # Base class for the SaProt protein encoder: declares the config type, supported attention backends and
    # which submodule outputs are recorded.
    config: SaProtConfig
    # Layer classes that should not be split (presumably for multi-device placement; confirm in PreTrainedModel).
    _no_split_modules = ["EvollaSaProtLayer"]
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_attention_backend = True

    # `index=1` selects the second element of `EvollaSaProtSelfAttention.forward`'s return value (the attention
    # weights). `layer_name` presumably distinguishes the self- ("attention") and cross- ("crossattention") parents.
    _can_record_outputs = {
        "hidden_states": EvollaSaProtLayer,
        "attentions": [OutputRecorder(EvollaSaProtSelfAttention, index=1, layer_name="attention")],
        "cross_attentions": [
            OutputRecorder(EvollaSaProtSelfAttention, index=1, layer_name="crossattention"),
        ],
    }

    def _init_weights(self, module):
        super()._init_weights(module)
        # `inv_freq` is a buffer, so it is re-filled here with the same formula as
        # `EvollaSaProtRotaryEmbedding.__init__`.
        if isinstance(module, EvollaSaProtRotaryEmbedding):
            inv_freq = 1.0 / (10000 ** (torch.arange(0, module.dim, 2, dtype=torch.int64).float() / module.dim))
            init.copy_(module.inv_freq, inv_freq)
class EvollaSaProtProteinEncoder(EvollaSaProtPreTrainedModel):
    # SaProt protein encoder: embeddings followed by the transformer encoder stack (no pooler is applied).
    def __init__(self, config: SaProtConfig):
        super().__init__(config)
        self.embeddings = EvollaSaProtEmbeddings(config)
        self.encoder = EvollaSaProtEncoder(config)
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    @merge_with_config_defaults
    @capture_outputs
    def forward(
        self,
        input_ids: torch.Tensor | None,
        attention_mask: torch.Tensor | None = None,
        **kwargs,
    ) -> tuple[torch.Tensor] | BaseModelOutputWithPoolingAndCrossAttentions:
        """
        Encode protein token ids.

        Args:
            input_ids (`torch.Tensor`): 2-D token ids `(batch_size, seq_length)`. Although annotated as optional, it is
                required: its size and device are read unconditionally.
            attention_mask (`torch.Tensor`, *optional*): 2-D mask; defaults to all ones.
            **kwargs: forwarded to the encoder.

        Returns:
            `BaseModelOutputWithPoolingAndCrossAttentions` with `last_hidden_state` set. `hidden_states`, `attentions`
            and `cross_attentions` are copied from the encoder output, which leaves them None itself; they are
            presumably filled in by `capture_outputs` (confirm). `pooler_output` is not set.
        """
        input_shape = input_ids.size()
        batch_size, seq_length = input_shape

        device = input_ids.device
        if attention_mask is None:
            attention_mask = torch.ones(((batch_size, seq_length)), device=device)

        # The embeddings consume the raw 2-D mask (token-dropout lengths and masking of the embeddings)...
        inputs_embeds = self.embeddings(input_ids=input_ids, attention_mask=attention_mask)
        # ...while the attention layers receive the mask converted by `create_bidirectional_mask`.
        attention_mask = create_bidirectional_mask(
            config=self.config,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
        )
        encoder_outputs = self.encoder(inputs_embeds, attention_mask=attention_mask, **kwargs)
        sequence_output = encoder_outputs[0]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )
class EvollaSequenceCompressorAttention(nn.Module):
    """
    Multi-head attention in which a set of latent queries attends over the concatenation of the input features and
    the latents themselves. Both inputs are LayerNorm-ed first.
    """

    def __init__(self, dim, dim_head=64, heads=8):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm_media = nn.LayerNorm(dim)
        self.norm_latents = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latents, mask):
        """
        Args:
            x (`torch.Tensor`): input features of shape `(b, n1, D)`.
            latents (`torch.Tensor`): latent features of shape `(b, n2, D)`; n2 is the number of latent tokens.
            mask (`torch.Tensor`): key mask of shape `(b, n1 + n2)` covering `x` followed by `latents`. Key positions
                whose mask value is not 1 are excluded.

        Returns:
            `torch.Tensor` of shape `(b, n2, D)`.
        """
        x = self.norm_media(x)
        latents = self.norm_latents(latents)

        num_heads = self.heads

        q = self.to_q(latents)
        # Keys and values are computed over the inputs followed by the latents.
        kv_input = torch.cat((x, latents), dim=-2)
        k, v = self.to_kv(kv_input).chunk(2, dim=-1)  # each: (b, n1 + n2, dim_head * heads)

        # (b, seq, heads * dim_head) -> (b, heads, seq, dim_head)
        q = q.view(q.size(0), q.size(1), num_heads, -1).permute(0, 2, 1, 3)
        k = k.view(k.size(0), k.size(1), num_heads, -1).permute(0, 2, 1, 3)
        v = v.view(v.size(0), v.size(1), num_heads, -1).permute(0, 2, 1, 3)

        q = q * self.scale

        sim = torch.matmul(q, k.transpose(-1, -2))
        # Subtract the (detached) row max for numerical stability before the softmax.
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()

        # Broadcast the key mask over heads and query positions instead of materialising a
        # (b, heads, n2, n1 + n2) tensor. Positions whose mask value is not 1 get -1e4.
        key_is_masked = mask[:, None, None, :] != 1
        sim = sim.masked_fill(key_is_masked, -1e4)
        attn = sim.softmax(dim=-1)

        out = torch.matmul(attn, v)
        # (b, heads, n2, dim_head) -> (b, n2, heads * dim_head)
        out = out.permute(0, 2, 1, 3)
        out = out.reshape(out.size(0), out.size(1), -1)

        return self.to_out(out)
class EvollaFeedForward(nn.Module):
    """Pre-norm MLP: LayerNorm -> Linear(dim, int(dim * mult)) -> GELU -> Linear back to dim; no linear biases."""

    def __init__(self, dim, mult=4):
        super().__init__()
        hidden_dim = int(dim * mult)
        self.norm = nn.LayerNorm(dim)
        self.fc1 = nn.Linear(dim, hidden_dim, bias=False)
        self.activation = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x):
        normed = self.norm(x)
        expanded = self.activation(self.fc1(normed))
        return self.fc2(expanded)
class EvollaSequenceCompressorResampler(nn.Module):
    """
    Compress protein representations into `config.resampler_num_latents` learned latent vectors. The latents pass
    through `config.resampler_depth` attention + feed-forward blocks and are then projected to `config.hidden_size`
    and LayerNorm-ed.
    """

    def __init__(self, config: EvollaConfig):
        super().__init__()
        protein_repr_dim = config.protein_encoder_config.hidden_size
        self.num_latents = config.resampler_num_latents
        self.latents = nn.Parameter(torch.randn(self.num_latents, protein_repr_dim), requires_grad=True)

        self.layers = nn.ModuleList([])
        for _ in range(config.resampler_depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        EvollaSequenceCompressorAttention(
                            dim=protein_repr_dim, dim_head=config.resampler_dim_head, heads=config.resampler_heads
                        ),
                        EvollaFeedForward(dim=protein_repr_dim, mult=config.resampler_ff_mult),
                    ]
                )
            )

        self.norm = nn.LayerNorm(config.hidden_size)
        self.protein_projector = nn.Linear(protein_repr_dim, config.hidden_size)

    def forward(self, embeds, mask):
        """
        Args:
            embeds (`torch.Tensor`): protein representations. Dim 0 is the batch and the last dim is
                `protein_encoder_config.hidden_size`; presumably `(batch, protein_len, dim)` (TODO confirm).
            mask (`torch.Tensor`): 2-D mask `(batch, protein_len)` over `embeds`.

        Returns:
            `torch.Tensor` of shape `(batch, resampler_num_latents, hidden_size)`.
        """
        batch_size = embeds.shape[0]
        mask_batch, _ = mask.shape  # also enforces a 2-D mask

        # Latent keys are always attendable: append an all-ones block (default float dtype) to the key mask.
        latent_mask = torch.ones(mask_batch, self.num_latents, device=mask.device)
        mask = torch.cat((mask, latent_mask), dim=1)  # (batch, protein_len + num_latents)

        # Share the learned latents across the batch via broadcasting rather than multiplying by a fresh ones tensor.
        latents = self.latents.unsqueeze(0).expand(batch_size, -1, -1).to(embeds.dtype)

        for attn, ff in self.layers:
            latents = attn(embeds, latents, mask) + latents
            latents = ff(latents) + latents

        transformed_feature = self.protein_projector(latents)
        return self.norm(transformed_feature)
@dataclass
@auto_docstring
class EvollaProteinEncoderModelOutput(ModelOutput):
    # Output container returned by `EvollaProteinEncoder.forward`.
    # Latent summary produced by `EvollaSequenceCompressorResampler`.
    sequence_compressor_output: torch.FloatTensor | None = None
    # Final hidden states of the SaProt protein encoder.
    last_hidden_state: torch.FloatTensor | None = None
    # Per-layer hidden states. NOTE(review): `EvollaProteinEncoder.forward` does not populate this field.
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    # Per-layer attention weights. NOTE(review): `EvollaProteinEncoder.forward` does not populate this field.
    attentions: tuple[torch.FloatTensor, ...] | None = None
class EvollaProteinEncoder(nn.Module):
    """Run the SaProt protein encoder, then compress its per-residue output with the sequence resampler."""

    def __init__(self, config: EvollaConfig):
        super().__init__()
        self.model = EvollaSaProtProteinEncoder(config=config.protein_encoder_config)
        self.sequence_compressor_resampler = EvollaSequenceCompressorResampler(config=config)

    @can_return_tuple
    def forward(self, input_ids: torch.LongTensor, attention_mask: torch.FloatTensor, **kwargs):
        """
        Args:
            input_ids: protein token ids passed to the SaProt encoder.
            attention_mask: mask passed both to the SaProt encoder and to the resampler.
            **kwargs: accepted for interface compatibility; not forwarded.

        Returns:
            `EvollaProteinEncoderModelOutput` with `sequence_compressor_output` and `last_hidden_state` set.
        """
        residue_states = self.model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        compressed = self.sequence_compressor_resampler(residue_states, attention_mask)
        return EvollaProteinEncoderModelOutput(
            sequence_compressor_output=compressed,
            last_hidden_state=residue_states,
        )
- class EvollaSequenceAlignerCrossAttention(nn.Module):
- def __init__(
- self,
- config,
- protein_encoder_dim: int | None = None,
- structure_encoder_dim: int | None = None,
- msa_encoder_dim: int | None = None,
- ):
- super().__init__()
- self.hidden_size = config.hidden_size
- self.num_attention_heads = config.num_attention_heads
- self.scale = self.num_attention_heads**-0.5
- self.attention_head_size = int(self.hidden_size / self.num_attention_heads)
- self.all_head_size = self.num_attention_heads * self.attention_head_size
- attention_probs_dropout_prob = config.aligner_attention_probs_dropout_prob
- enable_bias = config.aligner_enable_bias
- ffn_mult = config.aligner_ffn_mult
- self.query = nn.Linear(self.hidden_size, self.all_head_size)
- if protein_encoder_dim is not None:
- self.key_protein = nn.Linear(protein_encoder_dim, self.all_head_size)
- self.value_protein = nn.Linear(protein_encoder_dim, self.all_head_size)
- else:
- self.key_protein = None
- self.value_protein = None
- if structure_encoder_dim is not None:
- self.key_structure = nn.Linear(structure_encoder_dim, self.all_head_size)
- self.value_structure = nn.Linear(structure_encoder_dim, self.all_head_size)
- else:
- self.key_structure = None
- self.value_structure = None
- if msa_encoder_dim is not None:
- self.key_msa = nn.Linear(msa_encoder_dim, self.all_head_size)
- self.value_msa = nn.Linear(msa_encoder_dim, self.all_head_size)
- else:
- self.key_msa = None
- self.value_msa = None
- self.attention_norm = EvollaRMSNorm(self.hidden_size)
- self.dropout = nn.Dropout(attention_probs_dropout_prob)
- self.out_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=enable_bias)
- self.ff = EvollaFeedForward(self.hidden_size, ffn_mult)
- self.gate_attention = nn.Parameter(torch.tensor([0.0]))
- self.gate_ffw = nn.Parameter(torch.tensor([0.0]))
- def cross_attention(
- self,
- query_states,
- protein_key_value_states,
- structure_key_value_states,
- msa_key_value_states,
- query_attn_mask,
- protein_kv_attn_mask,
- structure_kv_attn_mask,
- msa_kv_attn_mask,
- ):
- """
- query_states: text
- key_value_states: protein
- query_states: [bs, query_seq_len, dim]
- key_value_states: [bs, kv_seq_len, dim]
- query_attn_mask: [bs, query_seq_len]
- kv_attn_mask: [bs, kv_seq_len]
- """
- # Concatenate protein and structure
- kv_attn_mask = [protein_kv_attn_mask, structure_kv_attn_mask, msa_kv_attn_mask]
- kv_attn_mask = [_ for _ in kv_attn_mask if _ is not None]
- if not kv_attn_mask:
- raise ValueError("At least one modality should be provided for cross attention.")
- kv_attn_mask = torch.cat(kv_attn_mask, dim=1)
- query_layer = self.attention_norm(query_states)
- # Warning: This place might cause issues, refers to
- # https://discuss.pytorch.org/t/cuda-error-cublas-status-not-supported-when-calling-cublasltmatmul-from-torch-nn-functional-linear/170214/13
- # Solution: add `DISABLE_ADDMM_CUDA_LT=1` as environment variable
- # Apply linear transformation to input_query, input_key, and input_value
- query_layer = self.query(query_layer) # [bs, querylength, dim]
- if self.key_protein is not None and self.value_protein is not None:
- protein_key_value_states = protein_key_value_states.to(query_states)
- key_layer_protein = self.key_protein(protein_key_value_states) # [bs, keylength, dim]
- value_layer_protein = self.value_protein(protein_key_value_states) # [bs, keylength, dim]
- else:
- key_layer_protein = None
- value_layer_protein = None
- if self.key_structure is not None and self.value_structure is not None:
- structure_key_value_states = structure_key_value_states.to(query_states)
- key_layer_structure = self.key_structure(structure_key_value_states) # [bs, keylength, dim]
- value_layer_structure = self.value_structure(structure_key_value_states) # [bs, keylength, dim]
- else:
- key_layer_structure = None
- value_layer_structure = None
- if self.key_msa is not None and self.value_msa is not None:
- msa_key_value_states = msa_key_value_states.to(query_states)
- key_layer_msa = self.key_msa(msa_key_value_states) # [bs, keylength, dim]
- value_layer_msa = self.value_msa(msa_key_value_states) # [bs, keylength, dim]
- else:
- key_layer_msa = None
- value_layer_msa = None
- key_layer = [key_layer_protein, key_layer_structure, key_layer_msa]
- key_layer = [_ for _ in key_layer if _ is not None]
- key_layer = torch.cat(key_layer, dim=1)
- value_layer = [value_layer_protein, value_layer_structure, value_layer_msa]
- value_layer = [_ for _ in value_layer if _ is not None]
- value_layer = torch.cat(value_layer, dim=1)
- new_query_layer_shape = query_layer.size()[:-1] + (
- self.num_attention_heads,
- self.attention_head_size,
- )
- query_layer = query_layer.view(*new_query_layer_shape).permute(0, 2, 1, 3)
- new_key_layer_shape = key_layer.size()[:-1] + (
- self.num_attention_heads,
- self.attention_head_size,
- )
- key_layer = key_layer.view(*new_key_layer_shape).permute(0, 2, 1, 3)
- new_value_layer_shape = value_layer.size()[:-1] + (
- self.num_attention_heads,
- self.attention_head_size,
- )
- value_layer = value_layer.view(*new_value_layer_shape).permute(0, 2, 1, 3)
- query_layer = query_layer * self.scale
- # attention_mask: [bs, 1, querylength, keylength]
- if query_attn_mask is None:
- query_attn_mask = torch.ones(query_states.size(0), query_states.size(1)).to(query_states.device)
- attention_mask = query_attn_mask[:, None, :, None] * kv_attn_mask[:, None, None, :]
- # Compute the scaled dot-product attention scores
- attn_weights = torch.matmul(query_layer, key_layer.transpose(-1, -2)) # [bs, numheads, querylength, keylength]
- attn_weights = attn_weights - attn_weights.amax(dim=-1, keepdim=True).detach() # To stabilize score
- attention_scores = attn_weights.masked_fill(
- (1 - attention_mask).bool(), torch.finfo(attn_weights.dtype).min
- ) # [bs, numheads, querylength, keylength]
- attention_probs = nn.Softmax(dim=-1)(attention_scores)
- # attention_probs_dropped = self.dropout(attention_probs)
- context_layer = torch.matmul(attention_probs, value_layer) # [bs, numheads, querylength, dim/numheads]
- context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
- new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
- context_layer = context_layer.view(*new_context_layer_shape)
- context_layer = self.out_proj(context_layer)
- return context_layer
- def forward(
- self,
- query_states,
- protein_kv_states,
- structure_kv_states,
- msa_kv_states,
- query_attn_mask,
- protein_kv_attn_mask=None,
- structure_kv_attn_mask=None,
- msa_kv_attn_mask=None,
- protein_batch_mask=None,
- structure_batch_mask=None,
- msa_batch_mask=None,
- past_key_values=None,
- ):
- if protein_kv_states is not None:
- bs, protein_kv_seq_len, dim = protein_kv_states.shape
- if protein_kv_attn_mask is None:
- protein_kv_attn_mask = (
- torch.ones(bs, protein_kv_seq_len).to(protein_batch_mask.device)
- * protein_batch_mask.expand(size=(protein_kv_seq_len, bs)).T
- ).to(protein_kv_states.device)
- else:
- protein_kv_attn_mask = None
- if structure_kv_states is not None:
- bs, structure_kv_seq_len, dim = structure_kv_states.shape
- if structure_kv_attn_mask is None:
- structure_kv_attn_mask = (
- torch.ones(bs, structure_kv_seq_len).to(protein_batch_mask.device)
- * structure_batch_mask.expand(size=(structure_kv_seq_len, bs)).T
- ).to(structure_kv_states.device)
- else:
- structure_kv_attn_mask = None
- if msa_kv_states is not None:
- bs, msa_kv_seq_len, dim = msa_kv_states.shape
- if msa_kv_attn_mask is None:
- msa_kv_attn_mask = (
- torch.ones(bs, msa_kv_seq_len).to(protein_batch_mask.device)
- * msa_batch_mask.expand(size=(msa_kv_seq_len, bs)).T
- ).to(msa_kv_states.device)
- else:
- msa_kv_attn_mask = None
- hidden_states = query_states
- # only when there's at least one valid modality, crossattention will be performed
- if (
- (protein_kv_states is not None and protein_kv_attn_mask.any())
- or (structure_kv_states is not None and structure_kv_attn_mask.any())
- or (msa_kv_states is not None and msa_kv_attn_mask.any())
- ):
- residual = hidden_states
- hidden_states = self.cross_attention(
- query_states=hidden_states,
- protein_key_value_states=protein_kv_states,
- structure_key_value_states=structure_kv_states,
- msa_key_value_states=msa_kv_states,
- query_attn_mask=query_attn_mask,
- protein_kv_attn_mask=protein_kv_attn_mask,
- structure_kv_attn_mask=structure_kv_attn_mask,
- msa_kv_attn_mask=msa_kv_attn_mask,
- ) # [bs, query_seq_len, dim]
- # tanh gate
- hidden_states = torch.tanh(self.gate_attention) * hidden_states
- hidden_states = residual + hidden_states # input_query
- residual = hidden_states
- hidden_states = self.ff(hidden_states) * torch.tanh(self.gate_ffw)
- hidden_states = residual + hidden_states
- return hidden_states
@use_kernel_forward_from_hub("RMSNorm")
class EvollaRMSNorm(nn.Module):
    """
    Root-mean-square layer normalization (same formulation as T5LayerNorm).

    Rescales the last dimension by its reciprocal RMS, with no mean subtraction and no bias, then multiplies
    by a learnable per-channel `weight` initialized to ones.
    """

    def __init__(self, hidden_size, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        original_dtype = hidden_states.dtype
        # Compute the statistics in float32, then cast back before applying the learnable scale.
        x = hidden_states.to(torch.float32)
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.variance_epsilon)
        return self.weight * (x * inv_rms).to(original_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
class EvollaRotaryEmbedding(nn.Module):
    """
    Produces rotary position embedding (RoPE) `cos`/`sin` tables for the given position ids.

    The inverse frequencies come from `compute_default_rope_parameters` when `rope_type == "default"`, and
    from `ROPE_INIT_FUNCTIONS[rope_type]` otherwise.
    """

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: EvollaConfig, device=None):
        super().__init__()
        self.config = config
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings
        self.rope_type = self.config.rope_parameters["rope_type"]

        if self.rope_type == "default":
            rope_init_fn: Callable = self.compute_default_rope_parameters
        else:
            rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

        # Non-persistent buffers: they follow the module across devices but are not saved in checkpoints.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

    @staticmethod
    def compute_default_rope_parameters(
        config: EvollaConfig | None = None,
        device: Optional["torch.device"] = None,
        seq_len: int | None = None,
    ) -> tuple["torch.Tensor", float]:
        """
        Compute the inverse frequencies of the original RoPE formulation.

        `inv_freq[i] = 1 / rope_theta ** (2 * i / dim)` for `i` in `[0, dim // 2)`, where `dim` is
        `config.head_dim` if set, otherwise `hidden_size // num_attention_heads`.

        Args:
            config ([`~transformers.PreTrainedConfig`]):
                The model configuration.
            device (`torch.device`):
                The device on which the inverse frequencies are created.
            seq_len (`int`, *optional*):
                Ignored by this RoPE type.

        Returns:
            Tuple of (`torch.Tensor`, `float`): the inverse frequencies and the post-processing scaling factor
            for cos/sin (always 1.0 here).
        """
        base = config.rope_parameters["rope_theta"]
        dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
        attention_factor = 1.0  # Unused in this type of RoPE

        exponents = torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim
        inv_freq = 1.0 / (base**exponents)
        return inv_freq, attention_factor

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        batch_size = position_ids.shape[0]
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(batch_size, -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        # Autocast falls back to "cpu" for MPS (and for any non-string device type).
        device_type = x.device.type
        if not isinstance(device_type, str) or device_type == "mps":
            device_type = "cpu"

        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
class EvollaMLP(nn.Module):
    """Gated feed-forward block: `down_proj(act_fn(gate_proj(x)) * up_proj(x))`."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        use_bias = config.mlp_bias
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=use_bias)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=use_bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=use_bias)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        gate = self.act_fn(self.gate_proj(x))
        return self.down_proj(gate * self.up_proj(x))
def rotate_half(x):
    """Rotate the last dimension by halves, mapping `[a, b]` to `[-b, a]` (the RoPE rotation helper)."""
    split = x.shape[-1] // 2
    first_half, second_half = x[..., :split], x[..., split:]
    return torch.cat((-second_half, first_half), dim=-1)
@use_kernel_func_from_hub("rotary_pos_emb")
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """
    Apply rotary position embeddings to the query and key tensors.

    Each tensor `t` becomes `t * cos + rotate_half(t) * sin`, after `cos`/`sin` gain a singleton axis at
    `unsqueeze_dim`.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): Cosine table. As produced by `EvollaRotaryEmbedding` its shape is
            `[batch_size, seq_len, rotary_dim]`.
        sin (`torch.Tensor`): Sine table with the same shape as `cos`.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Axis at which `cos` and `sin` are unsqueezed so they broadcast against `q` and `k`. Use 1 when
            `q`/`k` are `[batch_size, heads, seq_len, head_dim]`, and 2 when they are
            `[batch_size, seq_len, heads, head_dim]`.

    Returns:
        `tuple(torch.Tensor)`: the rotated query and key tensors.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)

    def _rotate(tensor):
        return (tensor * cos) + (rotate_half(tensor) * sin)

    return _rotate(q), _rotate(k)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat each key/value head `n_rep` times along the head axis.

    This is the same as `torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)`. The shape goes from
    `(batch, num_key_value_heads, seq_len, head_dim)` to `(batch, num_key_value_heads * n_rep, seq_len, head_dim)`.
    When `n_rep == 1` the input tensor itself is returned.
    """
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    expanded = hidden_states.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, head_dim)
    return expanded.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
@use_kernelized_func(apply_rotary_pos_emb)
class EvollaAttention(nn.Module):
    """
    Causal multi-head self-attention with rotary position embeddings.

    Queries use `num_attention_heads` heads. Keys/values use `num_key_value_heads` heads, so
    `num_key_value_groups` query heads share each key/value head. The attention kernel is chosen from
    `ALL_ATTENTION_FUNCTIONS` according to `config._attn_implementation`, defaulting to
    `eager_attention_forward`.
    """

    def __init__(self, config: EvollaConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        # Resolve head_dim the same way EvollaRotaryEmbedding does. A config carrying `head_dim=None` then falls
        # back to hidden_size // num_attention_heads instead of failing at `None**-0.5` below.
        self.head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        self.q_proj = nn.Linear(
            config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.v_proj = nn.Linear(
            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        attention_mask: torch.Tensor | None = None,
        past_key_values: Cache | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Run self-attention over `hidden_states`.

        Args:
            hidden_states: input of shape `[batch, seq_len, hidden_size]`.
            position_embeddings: the `(cos, sin)` tables from the rotary embedding. Required despite the
                `None` default.
            attention_mask: mask passed unchanged to the attention kernel.
            past_key_values: optional cache. If given, the new keys/values are appended for `layer_idx` and the
                full cached keys/values are attended to.

        Returns:
            `(attn_output, attn_weights)`: `attn_output` has shape `[batch, seq_len, hidden_size]`, and
            `attn_weights` is whatever the selected attention kernel returns.
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        # [batch, seq_len, heads * head_dim] -> [batch, heads, seq_len, head_dim]
        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_values is not None:
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)

        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
class EvollaDecoderLayer(GradientCheckpointingLayer):
    """
    Pre-norm decoder block (self-attention + gated MLP) that may also carry a protein cross-attention adapter.

    An `adapter` (an `EvollaSequenceAlignerCrossAttention` with a protein projection only) is attached when
    `(layer_idx + 1)` is a multiple of `max(num_hidden_layers // aligner_num_add_layers, 1)`.
    """

    def __init__(self, config: EvollaConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = EvollaAttention(config=config, layer_idx=layer_idx)
        self.mlp = EvollaMLP(config)
        self.input_layernorm = EvollaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = EvollaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        adapter_interval = max(config.num_hidden_layers // config.aligner_num_add_layers, 1)
        if (layer_idx + 1) % adapter_interval == 0:
            self.adapter = EvollaSequenceAlignerCrossAttention(
                config,
                protein_encoder_dim=config.hidden_size,
            )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        use_cache: bool | None = False,
        protein_kv_states: torch.Tensor | None = None,
        structure_kv_states: torch.Tensor | None = None,
        msa_kv_states: torch.Tensor | None = None,
        protein_batch_mask: torch.Tensor | None = None,
        structure_batch_mask: torch.Tensor | None = None,
        msa_batch_mask: torch.Tensor | None = None,
        query_attn_mask: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Apply self-attention and the MLP, each with a residual connection, then the adapter if this layer has one.

        The modality features, batch masks and `query_attn_mask` are used only by the adapter.
        """
        # Self-attention sub-block: pre-norm, then residual add.
        attn_output, _ = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = hidden_states + attn_output

        # Feed-forward sub-block: pre-norm, then residual add.
        hidden_states = hidden_states + self.mlp(self.post_attention_layernorm(hidden_states))

        adapter = getattr(self, "adapter", None)
        if adapter is not None:
            hidden_states = adapter(
                query_states=hidden_states,
                protein_kv_states=protein_kv_states,
                structure_kv_states=structure_kv_states,
                msa_kv_states=msa_kv_states,
                query_attn_mask=query_attn_mask,
                protein_batch_mask=protein_batch_mask,
                structure_batch_mask=structure_batch_mask,
                msa_batch_mask=msa_batch_mask,
            )

        return hidden_states
@auto_docstring
class EvollaPreTrainedModel(PreTrainedModel):
    # Shared configuration and weight-initialization logic for all Evolla models.
    config: EvollaConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = [
        "EvollaDecoderLayer",
        "EvollaSequenceCompressorResampler",
        "EvollaSequenceAlignerCrossAttention",
    ]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn = False  # see dependency on `EvollaSequenceCompressorResampler`
    _supports_sdpa = True
    _supports_flex_attn = False  # see dependency on `EvollaSequenceCompressorResampler`

    _can_compile_fullgraph = True
    _supports_attention_backend = False

    _can_record_outputs = {
        "hidden_states": EvollaDecoderLayer,
        "attentions": EvollaAttention,
    }

    @torch.no_grad()
    def _init_weights(self, module):
        """Apply the default initialization, then Evolla-specific values for adapters and resamplers."""
        std = self.config.initializer_range
        super()._init_weights(module)

        if isinstance(module, EvollaSequenceAlignerCrossAttention):
            # tanh(0) == 0, so with zero gates the adapter's gated residual branches start out contributing nothing.
            for gate in (module.gate_attention, module.gate_ffw):
                init.zeros_(gate)
            init.ones_(module.attention_norm.weight)
        elif isinstance(module, EvollaSequenceCompressorResampler):
            init.normal_(module.latents, mean=0.0, std=std)
class EvollaModel(EvollaPreTrainedModel):
    """
    Evolla decoder stack.

    Consists of token embeddings, a protein encoder whose compressed output feeds the decoder layers'
    cross-attention adapters, `num_hidden_layers` decoder layers, and a final RMSNorm.
    """

    def __init__(self, config: EvollaConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = nn.Embedding(self.vocab_size, config.hidden_size, self.padding_idx)
        self.protein_encoder = EvollaProteinEncoder(config=config)
        self.layers = nn.ModuleList(
            [EvollaDecoderLayer(config=config, layer_idx=layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = EvollaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = getattr(config, "gradient_checkpointing", False)
        self.rotary_emb = EvollaRotaryEmbedding(config=config)
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def _encode_proteins(self, protein_input_ids, protein_attention_mask):
        """
        Return `(protein_feats, protein_batch_mask)`.

        Both are `None` unless `protein_input_ids` and `protein_attention_mask` are both given. When they are,
        every sample in the batch is marked as having protein input.
        """
        if protein_input_ids is None or protein_attention_mask is None:
            return None, None
        encoded = self.protein_encoder(
            input_ids=protein_input_ids,
            attention_mask=protein_attention_mask,
        )
        batch_mask = torch.ones(
            protein_input_ids.shape[0],
            device=protein_input_ids.device,
            dtype=torch.bool,
        )
        return encoded.sequence_compressor_output, batch_mask

    @auto_docstring
    @merge_with_config_defaults
    @capture_outputs
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        protein_input_ids: torch.LongTensor | None = None,
        protein_attention_mask: torch.Tensor | None = None,
        structure_feats: torch.FloatTensor | None = None,
        msa_feats: torch.FloatTensor | None = None,
        structure_batch_mask: torch.Tensor | None = None,
        msa_batch_mask: torch.Tensor | None = None,
        **kwargs,
    ) -> tuple | BaseModelOutputWithPast:
        r"""
        protein_input_ids (torch.LongTensor):
            The input IDs for the protein sequence in structure-aware tokens. Should be of shape `(batch_size, protein_seq_length)` and type `torch.LongTensor`. Protein features are only computed when `protein_attention_mask` is also given.
        protein_attention_mask (torch.Tensor):
            The attention mask for the protein sequence. Should be of shape `(batch_size, protein_seq_length)` and type `torch.Tensor`.
        structure_feats (torch.FloatTensor):
            Purely structure-based features. Should be of shape `(batch_size, structure_seq_length, structure_feat_dim)` and type `torch.FloatTensor`. Dummy input for now.
        msa_feats (torch.FloatTensor):
            Purely MSA-based features. Should be of shape `(batch_size, msa_seq_length, msa_feat_dim)` and type `torch.FloatTensor`. Dummy input for now.
        structure_batch_mask (torch.Tensor):
            The batch mask to decide which protein sequences are purely structure-based. Should be of shape `(batch_size)` and type `torch.Tensor`. Should be paired with `structure_feats`. Dummy input for now.
        msa_batch_mask (torch.Tensor):
            The batch mask to decide which protein sequences are purely MSA-based. Should be of shape `(batch_size)` and type `torch.Tensor`. Should be paired with `msa_feats`. Dummy input for now.
        """
        # Exactly one of the two text inputs must be given.
        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        if position_ids is None:
            # New tokens continue numbering after whatever is already in the cache.
            past_seen_tokens = 0 if past_key_values is None else past_key_values.get_seq_length()
            position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
            position_ids = position_ids.unsqueeze(0)

        protein_feats, protein_batch_mask = self._encode_proteins(protein_input_ids, protein_attention_mask)

        causal_mask = create_causal_mask(
            config=self.config,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
        )

        hidden_states = inputs_embeds
        position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)

        # Every layer receives the same modality features; only layers that own an adapter use them.
        for decoder_layer in self.layers:
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                protein_kv_states=protein_feats,
                structure_kv_states=structure_feats,
                msa_kv_states=msa_feats,
                protein_batch_mask=protein_batch_mask,
                structure_batch_mask=structure_batch_mask,
                msa_batch_mask=msa_batch_mask,
                query_attn_mask=attention_mask,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        return BaseModelOutputWithPast(
            last_hidden_state=self.norm(hidden_states),
            past_key_values=past_key_values,
        )
class EvollaForProteinText2Text(EvollaPreTrainedModel, GenerationMixin):
    """Evolla causal language model: an `EvollaModel` followed by a bias-free linear LM head."""

    def __init__(self, config):
        super().__init__(config)
        self.model = EvollaModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, self.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        return self.model.set_input_embeddings(value)

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,  # text input ids
        attention_mask: torch.Tensor | None = None,  # text attention mask
        inputs_embeds: torch.FloatTensor | None = None,  # text input embeddings
        labels: torch.LongTensor | None = None,
        protein_input_ids: torch.LongTensor | None = None,
        protein_attention_mask: torch.Tensor | None = None,
        use_cache: bool | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs,
    ):
        r"""
        protein_input_ids (torch.LongTensor):
            The input IDs for the protein sequence. Should be of shape `(batch_size, protein_seq_length)` and type `torch.LongTensor`.
        protein_attention_mask (torch.Tensor):
            The attention mask for the protein sequence. Should be of shape `(batch_size, protein_seq_length)` and type `torch.Tensor`.

        Example:

        ```python
        >>> from transformers import EvollaProcessor, EvollaForProteinText2Text
        >>> model = EvollaForProteinText2Text.from_pretrained("westlake/Evolla-10B-hf")
        >>> processor = EvollaProcessor.from_pretrained("westlake/Evolla-10B-hf")

        >>> protein_information = {
        ...     "aa_seq": "your amino acid sequence",
        ...     "foldseek": "your foldseek sequence",
        ... }
        >>> question = "What is the function of this protein?"
        >>> message = [
        ...     {"role": "system", "content": "You are an AI expert that can answer any questions about protein."},
        ...     {"role": "user", "content": question},
        ... ]

        >>> inputs = processor(proteins=[protein_information], messages_list=[message], return_tensors="pt", padding="longest")
        >>> outputs = model.generate(**inputs)

        >>> print(processor.batch_decode(outputs, skip_special_tokens=True))
        ```"""
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            protein_input_ids=protein_input_ids,
            protein_attention_mask=protein_attention_mask,
            use_cache=use_cache,
            **kwargs,
        )

        # Only project the positions that are needed; an int `logits_to_keep` of 0 keeps every position.
        if isinstance(logits_to_keep, int):
            slice_indices = slice(-logits_to_keep, None)
        else:
            slice_indices = logits_to_keep
        logits = self.lm_head(outputs.last_hidden_state[:, slice_indices, :])

        loss = (
            self.loss_function(logits=logits, labels=labels, vocab_size=self.vocab_size, **kwargs)
            if labels is not None
            else None
        )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
# Public model classes exported by this module.
__all__ = ["EvollaForProteinText2Text", "EvollaModel", "EvollaPreTrainedModel"]
|