modeling_evolla.py 61 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/evolla/modular_evolla.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_evolla.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2025 Westlake Representational Learning Lab (Fajie Yuan Lab) team and the HuggingFace Inc. team. All rights reserved.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. import math
  21. from collections.abc import Callable
  22. from dataclasses import dataclass
  23. from typing import Optional
  24. import torch
  25. from torch import nn
  26. from ... import initialization as init
  27. from ...activations import ACT2FN
  28. from ...cache_utils import Cache, DynamicCache
  29. from ...generation import GenerationMixin
  30. from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
  31. from ...masking_utils import create_bidirectional_mask, create_causal_mask
  32. from ...modeling_layers import GradientCheckpointingLayer
  33. from ...modeling_outputs import (
  34. BaseModelOutputWithCrossAttentions,
  35. BaseModelOutputWithPast,
  36. BaseModelOutputWithPoolingAndCrossAttentions,
  37. CausalLMOutputWithPast,
  38. ModelOutput,
  39. )
  40. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  41. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  42. from ...processing_utils import Unpack
  43. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
  44. from ...utils.generic import maybe_autocast, merge_with_config_defaults
  45. from ...utils.output_capturing import OutputRecorder, capture_outputs
  46. from .configuration_evolla import EvollaConfig, SaProtConfig
  47. def create_position_ids_from_input_ids(input_ids, padding_idx):
  48. """
  49. Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
  50. are ignored. This is modified from fairseq's `utils.make_positions`.
  51. Args:
  52. x: torch.Tensor x:
  53. Returns: torch.Tensor
  54. """
  55. # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
  56. mask = input_ids.ne(padding_idx).int()
  57. incremental_indices = torch.cumsum(mask, dim=1).type_as(mask) * mask
  58. return incremental_indices.long() + padding_idx
  59. class EvollaSaProtEmbeddings(nn.Module):
  60. """
  61. Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
  62. """
  63. def __init__(self, config):
  64. super().__init__()
  65. self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  66. if config.emb_layer_norm_before:
  67. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  68. else:
  69. self.layer_norm = None
  70. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  71. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  72. self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
  73. self.register_buffer(
  74. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  75. )
  76. self.padding_idx = config.pad_token_id
  77. if self.position_embedding_type == "absolute":
  78. self.position_embeddings = nn.Embedding(
  79. config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
  80. )
  81. self.token_dropout = config.token_dropout
  82. self.mask_token_id = config.mask_token_id
  83. # remove the position_ids in EsmEmbeddings
  84. self.position_ids = None
  85. def forward(
  86. self,
  87. input_ids=None,
  88. attention_mask=None,
  89. position_ids=None,
  90. inputs_embeds=None,
  91. ):
  92. if position_ids is None:
  93. if input_ids is not None:
  94. # Create the position ids from the input token ids. Any padded tokens remain padded.
  95. position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx)
  96. else:
  97. position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
  98. if inputs_embeds is None:
  99. inputs_embeds = self.word_embeddings(input_ids)
  100. # Note that if we want to support EVOLLA_SA_PROT-1 (not 1b!) in future then we need to support an
  101. # embedding_scale factor here.
  102. embeddings = inputs_embeds
  103. # Matt: EVOLLA_SA_PROT has the option to handle masking in MLM in a slightly unusual way. If the token_dropout
  104. # flag is False then it is handled in the same was as BERT/RoBERTa. If it is set to True, however,
  105. # masked tokens are treated as if they were selected for input dropout and zeroed out.
  106. # This "mask-dropout" is compensated for when masked tokens are not present, by scaling embeddings by
  107. # a factor of (fraction of unmasked tokens during training) / (fraction of unmasked tokens in sample).
  108. # This is analogous to the way that dropout layers scale down outputs during evaluation when not
  109. # actually dropping out values (or, equivalently, scale up their un-dropped outputs in training).
  110. if self.token_dropout and input_ids is not None:
  111. embeddings = embeddings.masked_fill((input_ids == self.mask_token_id).unsqueeze(-1), 0.0)
  112. mask_ratio_train = 0.15 * 0.8 # Hardcoded as the ratio used in all EVOLLA_SA_PROT model training runs
  113. src_lengths = attention_mask.sum(-1) if attention_mask is not None else input_ids.shape[1]
  114. mask_ratio_observed = (input_ids == self.mask_token_id).sum(-1).float() / src_lengths
  115. embeddings = (embeddings * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]).to(
  116. embeddings.dtype
  117. )
  118. if self.position_embedding_type == "absolute":
  119. position_embeddings = self.position_embeddings(position_ids)
  120. embeddings = embeddings + position_embeddings
  121. if self.layer_norm is not None:
  122. embeddings = self.layer_norm(embeddings)
  123. if attention_mask is not None:
  124. embeddings = (embeddings * attention_mask.unsqueeze(-1)).to(embeddings.dtype)
  125. # Matt: I think this line was copied incorrectly from BERT, disabling it for now.
  126. # embeddings = self.dropout(embeddings)
  127. return embeddings
  128. def create_position_ids_from_inputs_embeds(self, inputs_embeds):
  129. """
  130. We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
  131. Args:
  132. inputs_embeds: torch.Tensor
  133. Returns: torch.Tensor
  134. """
  135. input_shape = inputs_embeds.size()[:-1]
  136. sequence_length = input_shape[1]
  137. position_ids = torch.arange(
  138. self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
  139. )
  140. return position_ids.unsqueeze(0).expand(input_shape)
  141. def rotate_half_esm(x):
  142. x1, x2 = x.chunk(2, dim=-1)
  143. return torch.cat((-x2, x1), dim=-1)
  144. def apply_rotary_pos_emb_esm(x, cos, sin):
  145. cos = cos[:, :, : x.shape[-2], :]
  146. sin = sin[:, :, : x.shape[-2], :]
  147. return (x * cos) + (rotate_half_esm(x) * sin)
  148. class EvollaSaProtRotaryEmbedding(nn.Module):
  149. """
  150. Rotary position embeddings based on those in
  151. [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). Query and keys are transformed by rotation
  152. matrices which depend on their relative positions.
  153. """
  154. inv_freq: torch.Tensor # fix linting for `register_buffer`
  155. def __init__(self, dim: int):
  156. super().__init__()
  157. self.dim = dim
  158. # Generate and save the inverse frequency buffer (non trainable)
  159. inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
  160. self.register_buffer("inv_freq", inv_freq)
  161. self._seq_len_cached = None
  162. self._cos_cached = None
  163. self._sin_cached = None
  164. def _update_cos_sin_tables(self, x, seq_dimension=2):
  165. seq_len = x.shape[seq_dimension]
  166. # Reset the tables if the sequence length has changed,
  167. # or if we're on a new device (possibly due to tracing for instance)
  168. if seq_len != self._seq_len_cached or self._cos_cached.device != x.device:
  169. self._seq_len_cached = seq_len
  170. t = torch.arange(x.shape[seq_dimension], device=x.device).type_as(self.inv_freq)
  171. freqs = torch.outer(t, self.inv_freq)
  172. emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
  173. self._cos_cached = emb.cos()[None, None, :, :]
  174. self._sin_cached = emb.sin()[None, None, :, :]
  175. return self._cos_cached, self._sin_cached
  176. def forward(self, q: torch.Tensor, k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
  177. self._cos_cached, self._sin_cached = self._update_cos_sin_tables(k, seq_dimension=-2)
  178. return (
  179. apply_rotary_pos_emb_esm(q, self._cos_cached, self._sin_cached).to(dtype=q.dtype),
  180. apply_rotary_pos_emb_esm(k, self._cos_cached, self._sin_cached).to(dtype=k.dtype),
  181. )
  182. def eager_attention_forward(
  183. module: nn.Module,
  184. query: torch.Tensor,
  185. key: torch.Tensor,
  186. value: torch.Tensor,
  187. attention_mask: torch.Tensor | None,
  188. scaling: float | None = None,
  189. dropout: float = 0.0,
  190. **kwargs: Unpack[TransformersKwargs],
  191. ):
  192. if scaling is None:
  193. scaling = query.size(-1) ** -0.5
  194. # Take the dot product between "query" and "key" to get the raw attention scores.
  195. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
  196. if attention_mask is not None:
  197. attn_weights = attn_weights + attention_mask
  198. attn_weights = nn.functional.softmax(attn_weights, dim=-1)
  199. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  200. attn_output = torch.matmul(attn_weights, value)
  201. attn_output = attn_output.transpose(1, 2).contiguous()
  202. return attn_output, attn_weights
  203. class EvollaSaProtSelfAttention(nn.Module):
  204. def __init__(self, config, position_embedding_type=None, layer_idx=None, is_cross_attention=False):
  205. super().__init__()
  206. self.config = config
  207. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  208. raise ValueError(
  209. f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
  210. f"heads ({config.num_attention_heads})"
  211. )
  212. self.num_attention_heads = config.num_attention_heads
  213. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  214. self.all_head_size = self.num_attention_heads * self.attention_head_size
  215. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  216. self.key = nn.Linear(config.hidden_size, self.all_head_size)
  217. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  218. self.dropout = config.attention_probs_dropout_prob
  219. self.rotary_embeddings = None
  220. self.position_embedding_type = position_embedding_type or getattr(
  221. config, "position_embedding_type", "absolute"
  222. )
  223. if self.position_embedding_type == "rotary":
  224. self.rotary_embeddings = EvollaSaProtRotaryEmbedding(dim=self.attention_head_size)
  225. self.is_decoder = config.is_decoder
  226. self.layer_idx = layer_idx
  227. self.scaling = 1.0
  228. self.is_causal = self.is_decoder and not is_cross_attention
  229. def forward(
  230. self,
  231. hidden_states: torch.Tensor,
  232. attention_mask: torch.FloatTensor | None = None,
  233. encoder_hidden_states: torch.FloatTensor | None = None,
  234. encoder_attention_mask: torch.FloatTensor | None = None,
  235. **kwargs: Unpack[TransformersKwargs],
  236. ) -> tuple[torch.Tensor]:
  237. input_shape = hidden_states.shape[:-1]
  238. hidden_shape = (*input_shape, -1, self.attention_head_size)
  239. query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
  240. is_cross_attention = encoder_hidden_states is not None
  241. current_states = encoder_hidden_states if is_cross_attention else hidden_states
  242. attention_mask = encoder_attention_mask if is_cross_attention else attention_mask
  243. key_layer = self.key(current_states).view(hidden_shape).transpose(1, 2)
  244. value_layer = self.value(current_states).view(hidden_shape).transpose(1, 2)
  245. # Matt: Our BERT model (which this code was derived from) scales attention logits down by sqrt(head_dim).
  246. # EVOLLA_SA_PROT scales the query down by the same factor instead. Modulo numerical stability these are equivalent,
  247. # but not when rotary embeddings get involved. Therefore, we scale the query here to match the original
  248. # EVOLLA_SA_PROT code and fix rotary embeddings.
  249. query_layer = query_layer * self.attention_head_size**-0.5
  250. if self.position_embedding_type == "rotary":
  251. query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)
  252. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  253. self.config._attn_implementation, eager_attention_forward
  254. )
  255. attn_output, attn_weights = attention_interface(
  256. self,
  257. query_layer,
  258. key_layer,
  259. value_layer,
  260. attention_mask,
  261. dropout=0.0 if not self.training else self.dropout,
  262. scaling=self.scaling,
  263. **kwargs,
  264. )
  265. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  266. return attn_output, attn_weights
  267. class EvollaSaProtSelfOutput(nn.Module):
  268. def __init__(self, config):
  269. super().__init__()
  270. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  271. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  272. def forward(self, hidden_states, input_tensor):
  273. hidden_states = self.dense(hidden_states)
  274. hidden_states = self.dropout(hidden_states)
  275. hidden_states = hidden_states + input_tensor
  276. return hidden_states
  277. class EvollaSaProtAttention(nn.Module):
  278. def __init__(self, config, layer_idx=None, is_cross_attention=False):
  279. super().__init__()
  280. self.self = EvollaSaProtSelfAttention(config, layer_idx=layer_idx, is_cross_attention=is_cross_attention)
  281. self.output = EvollaSaProtSelfOutput(config)
  282. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  283. def forward(
  284. self,
  285. hidden_states,
  286. attention_mask=None,
  287. encoder_hidden_states=None,
  288. encoder_attention_mask=None,
  289. **kwargs: Unpack[TransformersKwargs],
  290. ):
  291. hidden_states_ln = self.LayerNorm(hidden_states)
  292. attn_output, _ = self.self(
  293. hidden_states_ln,
  294. attention_mask=attention_mask,
  295. encoder_hidden_states=encoder_hidden_states,
  296. encoder_attention_mask=encoder_attention_mask,
  297. **kwargs,
  298. )
  299. attn_output = self.output(attn_output, hidden_states)
  300. return attn_output
  301. def gelu(x):
  302. """
  303. This is the gelu implementation from the original EVOLLA_SA_PROT repo. Using F.gelu yields subtly wrong results.
  304. """
  305. return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
  306. class EvollaSaProtIntermediate(nn.Module):
  307. def __init__(self, config):
  308. super().__init__()
  309. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  310. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  311. hidden_states = self.dense(hidden_states)
  312. hidden_states = gelu(hidden_states)
  313. return hidden_states
  314. class EvollaSaProtOutput(nn.Module):
  315. def __init__(self, config):
  316. super().__init__()
  317. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  318. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  319. def forward(self, hidden_states, input_tensor):
  320. hidden_states = self.dense(hidden_states)
  321. hidden_states = self.dropout(hidden_states)
  322. hidden_states = hidden_states + input_tensor
  323. return hidden_states
  324. class EvollaSaProtLayer(GradientCheckpointingLayer):
  325. def __init__(self, config):
  326. super().__init__()
  327. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  328. self.seq_len_dim = 1
  329. self.attention = EvollaSaProtAttention(config)
  330. self.is_decoder = config.is_decoder
  331. self.add_cross_attention = config.add_cross_attention
  332. if self.add_cross_attention:
  333. if not self.is_decoder:
  334. raise RuntimeError(f"{self} should be used as a decoder model if cross attention is added")
  335. self.crossattention = EvollaSaProtAttention(config, is_cross_attention=True)
  336. self.intermediate = EvollaSaProtIntermediate(config)
  337. self.output = EvollaSaProtOutput(config)
  338. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  339. def forward(
  340. self,
  341. hidden_states,
  342. attention_mask=None,
  343. encoder_hidden_states=None,
  344. encoder_attention_mask=None,
  345. **kwargs: Unpack[TransformersKwargs],
  346. ):
  347. attention_output = self.attention(
  348. hidden_states,
  349. attention_mask=attention_mask,
  350. **kwargs,
  351. )
  352. if self.is_decoder and encoder_hidden_states is not None:
  353. if not hasattr(self, "crossattention"):
  354. raise AttributeError(
  355. f"If `encoder_hidden_states` are passed, {self} has to be instantiated"
  356. " with cross-attention layers by setting `config.add_cross_attention=True`"
  357. )
  358. attention_output = self.crossattention(
  359. attention_output,
  360. attention_mask=attention_mask,
  361. encoder_hidden_states=encoder_hidden_states,
  362. encoder_attention_mask=encoder_attention_mask,
  363. **kwargs,
  364. )
  365. layer_output = self.feed_forward_chunk(attention_output)
  366. return layer_output
  367. def feed_forward_chunk(self, attention_output):
  368. attention_output_ln = self.LayerNorm(attention_output)
  369. intermediate_output = self.intermediate(attention_output_ln)
  370. layer_output = self.output(intermediate_output, attention_output)
  371. return layer_output
  372. class EvollaSaProtEncoder(nn.Module):
  373. def __init__(self, config):
  374. super().__init__()
  375. self.config = config
  376. self.layer = nn.ModuleList([EvollaSaProtLayer(config) for _ in range(config.num_hidden_layers)])
  377. self.emb_layer_norm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  378. self.gradient_checkpointing = False
  379. @can_return_tuple
  380. def forward(
  381. self,
  382. hidden_states,
  383. attention_mask=None,
  384. encoder_hidden_states=None,
  385. encoder_attention_mask=None,
  386. **kwargs: Unpack[TransformersKwargs],
  387. ):
  388. for i, layer_module in enumerate(self.layer):
  389. hidden_states = layer_module(
  390. hidden_states,
  391. attention_mask=attention_mask,
  392. encoder_hidden_states=encoder_hidden_states,
  393. encoder_attention_mask=encoder_attention_mask,
  394. **kwargs,
  395. )
  396. if self.emb_layer_norm_after:
  397. hidden_states = self.emb_layer_norm_after(hidden_states)
  398. return BaseModelOutputWithCrossAttentions(last_hidden_state=hidden_states)
  399. class EvollaSaProtPooler(nn.Module):
  400. def __init__(self, config):
  401. super().__init__()
  402. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  403. self.activation = nn.Tanh()
  404. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  405. # We "pool" the model by simply taking the hidden state corresponding
  406. # to the first token.
  407. first_token_tensor = hidden_states[:, 0]
  408. pooled_output = self.dense(first_token_tensor)
  409. pooled_output = self.activation(pooled_output)
  410. return pooled_output
  411. @auto_docstring
  412. class EvollaSaProtPreTrainedModel(PreTrainedModel):
  413. config: SaProtConfig
  414. _no_split_modules = ["EvollaSaProtLayer"]
  415. _supports_flash_attn = True
  416. _supports_sdpa = True
  417. _supports_flex_attn = True
  418. _supports_attention_backend = True
  419. _can_record_outputs = {
  420. "hidden_states": EvollaSaProtLayer,
  421. "attentions": [OutputRecorder(EvollaSaProtSelfAttention, index=1, layer_name="attention")],
  422. "cross_attentions": [
  423. OutputRecorder(EvollaSaProtSelfAttention, index=1, layer_name="crossattention"),
  424. ],
  425. }
  426. def _init_weights(self, module):
  427. super()._init_weights(module)
  428. if isinstance(module, EvollaSaProtRotaryEmbedding):
  429. inv_freq = 1.0 / (10000 ** (torch.arange(0, module.dim, 2, dtype=torch.int64).float() / module.dim))
  430. init.copy_(module.inv_freq, inv_freq)
  431. class EvollaSaProtProteinEncoder(EvollaSaProtPreTrainedModel):
  432. def __init__(self, config: SaProtConfig):
  433. super().__init__(config)
  434. self.embeddings = EvollaSaProtEmbeddings(config)
  435. self.encoder = EvollaSaProtEncoder(config)
  436. self.post_init()
  437. def get_input_embeddings(self):
  438. return self.embeddings.word_embeddings
  439. def set_input_embeddings(self, value):
  440. self.embeddings.word_embeddings = value
  441. @merge_with_config_defaults
  442. @capture_outputs
  443. def forward(
  444. self,
  445. input_ids: torch.Tensor | None,
  446. attention_mask: torch.Tensor | None = None,
  447. **kwargs,
  448. ) -> tuple[torch.Tensor] | BaseModelOutputWithPoolingAndCrossAttentions:
  449. input_shape = input_ids.size()
  450. batch_size, seq_length = input_shape
  451. device = input_ids.device
  452. if attention_mask is None:
  453. attention_mask = torch.ones(((batch_size, seq_length)), device=device)
  454. inputs_embeds = self.embeddings(input_ids=input_ids, attention_mask=attention_mask)
  455. attention_mask = create_bidirectional_mask(
  456. config=self.config,
  457. inputs_embeds=inputs_embeds,
  458. attention_mask=attention_mask,
  459. )
  460. encoder_outputs = self.encoder(inputs_embeds, attention_mask=attention_mask, **kwargs)
  461. sequence_output = encoder_outputs[0]
  462. return BaseModelOutputWithPoolingAndCrossAttentions(
  463. last_hidden_state=sequence_output,
  464. hidden_states=encoder_outputs.hidden_states,
  465. attentions=encoder_outputs.attentions,
  466. cross_attentions=encoder_outputs.cross_attentions,
  467. )
  468. class EvollaSequenceCompressorAttention(nn.Module):
  469. def __init__(self, dim, dim_head=64, heads=8):
  470. super().__init__()
  471. self.scale = dim_head**-0.5
  472. self.heads = heads
  473. inner_dim = dim_head * heads
  474. self.norm_media = nn.LayerNorm(dim)
  475. self.norm_latents = nn.LayerNorm(dim)
  476. self.to_q = nn.Linear(dim, inner_dim, bias=False)
  477. self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
  478. self.to_out = nn.Linear(inner_dim, dim, bias=False)
  479. def forward(self, x, latents, mask):
  480. """
  481. Args:
  482. x (torch.Tensor): image features
  483. shape (b, n1, D)
  484. latent (torch.Tensor): latent features
  485. shape (b, n2, D); n2: num of latent tokens
  486. """
  487. x = self.norm_media(x)
  488. latents = self.norm_latents(latents)
  489. h = self.heads
  490. q = self.to_q(latents)
  491. kv_input = torch.cat((x, latents), dim=-2)
  492. k, v = self.to_kv(kv_input).chunk(
  493. 2, dim=-1
  494. ) # each: batch_size, max_protein_length+num_latents, dim_head*num_heads
  495. q = q.view(q.size(0), q.size(1), h, -1).permute(0, 2, 1, 3)
  496. k = k.view(k.size(0), k.size(1), h, -1).permute(0, 2, 1, 3)
  497. v = v.view(v.size(0), v.size(1), h, -1).permute(0, 2, 1, 3)
  498. q = q * self.scale # batch_size, num_heads, num_latents, dim_head
  499. # attention
  500. sim = torch.matmul(q, k.transpose(-1, -2))
  501. sim = sim - sim.amax(dim=-1, keepdim=True).detach()
  502. bs, nh, skd, okd = sim.shape
  503. ones = torch.ones(nh, skd).to(mask.device) # Create a tensor of ones with shape (nh, skd)
  504. mask_exp = mask[:, None, None, :]
  505. ones_exp = ones[None, :, :, None]
  506. mask = mask_exp * ones_exp
  507. sim = sim.masked_fill((1 - mask).bool(), -1e4)
  508. attn = sim.softmax(dim=-1)
  509. out = torch.matmul(attn, v)
  510. out = out.permute(0, 2, 1, 3)
  511. # [batch, seq, head, features] -> [batch, seq, head*features]
  512. out = out.reshape(out.size(0), out.size(1), -1)
  513. return self.to_out(out)
  514. class EvollaFeedForward(nn.Module):
  515. def __init__(self, dim, mult=4):
  516. super().__init__()
  517. inner_dim = int(dim * mult)
  518. self.norm = nn.LayerNorm(dim)
  519. self.fc1 = nn.Linear(dim, inner_dim, bias=False)
  520. self.activation = nn.GELU()
  521. self.fc2 = nn.Linear(inner_dim, dim, bias=False)
  522. def forward(self, x):
  523. return self.fc2(self.activation(self.fc1(self.norm(x))))
  524. class EvollaSequenceCompressorResampler(nn.Module):
  525. def __init__(self, config: EvollaConfig):
  526. super().__init__()
  527. protein_repr_dim = config.protein_encoder_config.hidden_size
  528. self.num_latents = config.resampler_num_latents
  529. self.latents = nn.Parameter(torch.randn(self.num_latents, protein_repr_dim), requires_grad=True)
  530. self.layers = nn.ModuleList([])
  531. for _ in range(config.resampler_depth):
  532. self.layers.append(
  533. nn.ModuleList(
  534. [
  535. EvollaSequenceCompressorAttention(
  536. dim=protein_repr_dim, dim_head=config.resampler_dim_head, heads=config.resampler_heads
  537. ),
  538. EvollaFeedForward(dim=protein_repr_dim, mult=config.resampler_ff_mult),
  539. ]
  540. )
  541. )
  542. self.norm = nn.LayerNorm(config.hidden_size)
  543. self.protein_projector = nn.Linear(protein_repr_dim, config.hidden_size)
  544. def forward(self, embeds, mask):
  545. b = embeds.shape[0]
  546. bs, _ = mask.shape # bs, max_protein_length
  547. latent_mask = torch.ones(bs, self.num_latents).to(mask.device)
  548. mask = torch.cat((mask, latent_mask), dim=1) # bs, max_protein_length + num_latents
  549. # blocks
  550. ones = torch.ones(b).to(self.latents.device)
  551. latents = self.latents[None] * ones.view(-1, 1, 1) # [b,n,d]
  552. latents = latents.to(embeds.dtype)
  553. for attn, ff in self.layers:
  554. latents = attn(embeds, latents, mask) + latents
  555. latents = ff(latents) + latents
  556. transformed_feature = self.protein_projector(latents)
  557. return self.norm(transformed_feature)
@dataclass
@auto_docstring
class EvollaProteinEncoderModelOutput(ModelOutput):
    # Output container returned by `EvollaProteinEncoder.forward`.
    # NOTE: field order defines the tuple layout of this ModelOutput; do not reorder.

    # Compressed protein representation from `EvollaSequenceCompressorResampler`: one vector per learned
    # latent, projected to the text model's `hidden_size`.
    sequence_compressor_output: torch.FloatTensor | None = None
    # Final hidden states of the underlying `EvollaSaProtProteinEncoder`.
    last_hidden_state: torch.FloatTensor | None = None
    # NOTE(review): `EvollaProteinEncoder.forward` does not currently populate `hidden_states` or `attentions`;
    # they stay None.
    hidden_states: tuple[torch.FloatTensor, ...] | None = None
    attentions: tuple[torch.FloatTensor, ...] | None = None
  565. class EvollaProteinEncoder(nn.Module):
  566. def __init__(self, config: EvollaConfig):
  567. super().__init__()
  568. self.model = EvollaSaProtProteinEncoder(config=config.protein_encoder_config)
  569. self.sequence_compressor_resampler = EvollaSequenceCompressorResampler(config=config)
  570. @can_return_tuple
  571. def forward(self, input_ids: torch.LongTensor, attention_mask: torch.FloatTensor, **kwargs):
  572. protein_output = self.model(input_ids=input_ids, attention_mask=attention_mask)
  573. protein_embeds = protein_output.last_hidden_state
  574. sequence_repr = self.sequence_compressor_resampler(protein_embeds, attention_mask)
  575. return EvollaProteinEncoderModelOutput(
  576. sequence_compressor_output=sequence_repr,
  577. last_hidden_state=protein_output.last_hidden_state,
  578. )
  579. class EvollaSequenceAlignerCrossAttention(nn.Module):
  580. def __init__(
  581. self,
  582. config,
  583. protein_encoder_dim: int | None = None,
  584. structure_encoder_dim: int | None = None,
  585. msa_encoder_dim: int | None = None,
  586. ):
  587. super().__init__()
  588. self.hidden_size = config.hidden_size
  589. self.num_attention_heads = config.num_attention_heads
  590. self.scale = self.num_attention_heads**-0.5
  591. self.attention_head_size = int(self.hidden_size / self.num_attention_heads)
  592. self.all_head_size = self.num_attention_heads * self.attention_head_size
  593. attention_probs_dropout_prob = config.aligner_attention_probs_dropout_prob
  594. enable_bias = config.aligner_enable_bias
  595. ffn_mult = config.aligner_ffn_mult
  596. self.query = nn.Linear(self.hidden_size, self.all_head_size)
  597. if protein_encoder_dim is not None:
  598. self.key_protein = nn.Linear(protein_encoder_dim, self.all_head_size)
  599. self.value_protein = nn.Linear(protein_encoder_dim, self.all_head_size)
  600. else:
  601. self.key_protein = None
  602. self.value_protein = None
  603. if structure_encoder_dim is not None:
  604. self.key_structure = nn.Linear(structure_encoder_dim, self.all_head_size)
  605. self.value_structure = nn.Linear(structure_encoder_dim, self.all_head_size)
  606. else:
  607. self.key_structure = None
  608. self.value_structure = None
  609. if msa_encoder_dim is not None:
  610. self.key_msa = nn.Linear(msa_encoder_dim, self.all_head_size)
  611. self.value_msa = nn.Linear(msa_encoder_dim, self.all_head_size)
  612. else:
  613. self.key_msa = None
  614. self.value_msa = None
  615. self.attention_norm = EvollaRMSNorm(self.hidden_size)
  616. self.dropout = nn.Dropout(attention_probs_dropout_prob)
  617. self.out_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=enable_bias)
  618. self.ff = EvollaFeedForward(self.hidden_size, ffn_mult)
  619. self.gate_attention = nn.Parameter(torch.tensor([0.0]))
  620. self.gate_ffw = nn.Parameter(torch.tensor([0.0]))
  621. def cross_attention(
  622. self,
  623. query_states,
  624. protein_key_value_states,
  625. structure_key_value_states,
  626. msa_key_value_states,
  627. query_attn_mask,
  628. protein_kv_attn_mask,
  629. structure_kv_attn_mask,
  630. msa_kv_attn_mask,
  631. ):
  632. """
  633. query_states: text
  634. key_value_states: protein
  635. query_states: [bs, query_seq_len, dim]
  636. key_value_states: [bs, kv_seq_len, dim]
  637. query_attn_mask: [bs, query_seq_len]
  638. kv_attn_mask: [bs, kv_seq_len]
  639. """
  640. # Concatenate protein and structure
  641. kv_attn_mask = [protein_kv_attn_mask, structure_kv_attn_mask, msa_kv_attn_mask]
  642. kv_attn_mask = [_ for _ in kv_attn_mask if _ is not None]
  643. if not kv_attn_mask:
  644. raise ValueError("At least one modality should be provided for cross attention.")
  645. kv_attn_mask = torch.cat(kv_attn_mask, dim=1)
  646. query_layer = self.attention_norm(query_states)
  647. # Warning: This place might cause issues, refers to
  648. # https://discuss.pytorch.org/t/cuda-error-cublas-status-not-supported-when-calling-cublasltmatmul-from-torch-nn-functional-linear/170214/13
  649. # Solution: add `DISABLE_ADDMM_CUDA_LT=1` as environment variable
  650. # Apply linear transformation to input_query, input_key, and input_value
  651. query_layer = self.query(query_layer) # [bs, querylength, dim]
  652. if self.key_protein is not None and self.value_protein is not None:
  653. protein_key_value_states = protein_key_value_states.to(query_states)
  654. key_layer_protein = self.key_protein(protein_key_value_states) # [bs, keylength, dim]
  655. value_layer_protein = self.value_protein(protein_key_value_states) # [bs, keylength, dim]
  656. else:
  657. key_layer_protein = None
  658. value_layer_protein = None
  659. if self.key_structure is not None and self.value_structure is not None:
  660. structure_key_value_states = structure_key_value_states.to(query_states)
  661. key_layer_structure = self.key_structure(structure_key_value_states) # [bs, keylength, dim]
  662. value_layer_structure = self.value_structure(structure_key_value_states) # [bs, keylength, dim]
  663. else:
  664. key_layer_structure = None
  665. value_layer_structure = None
  666. if self.key_msa is not None and self.value_msa is not None:
  667. msa_key_value_states = msa_key_value_states.to(query_states)
  668. key_layer_msa = self.key_msa(msa_key_value_states) # [bs, keylength, dim]
  669. value_layer_msa = self.value_msa(msa_key_value_states) # [bs, keylength, dim]
  670. else:
  671. key_layer_msa = None
  672. value_layer_msa = None
  673. key_layer = [key_layer_protein, key_layer_structure, key_layer_msa]
  674. key_layer = [_ for _ in key_layer if _ is not None]
  675. key_layer = torch.cat(key_layer, dim=1)
  676. value_layer = [value_layer_protein, value_layer_structure, value_layer_msa]
  677. value_layer = [_ for _ in value_layer if _ is not None]
  678. value_layer = torch.cat(value_layer, dim=1)
  679. new_query_layer_shape = query_layer.size()[:-1] + (
  680. self.num_attention_heads,
  681. self.attention_head_size,
  682. )
  683. query_layer = query_layer.view(*new_query_layer_shape).permute(0, 2, 1, 3)
  684. new_key_layer_shape = key_layer.size()[:-1] + (
  685. self.num_attention_heads,
  686. self.attention_head_size,
  687. )
  688. key_layer = key_layer.view(*new_key_layer_shape).permute(0, 2, 1, 3)
  689. new_value_layer_shape = value_layer.size()[:-1] + (
  690. self.num_attention_heads,
  691. self.attention_head_size,
  692. )
  693. value_layer = value_layer.view(*new_value_layer_shape).permute(0, 2, 1, 3)
  694. query_layer = query_layer * self.scale
  695. # attention_mask: [bs, 1, querylength, keylength]
  696. if query_attn_mask is None:
  697. query_attn_mask = torch.ones(query_states.size(0), query_states.size(1)).to(query_states.device)
  698. attention_mask = query_attn_mask[:, None, :, None] * kv_attn_mask[:, None, None, :]
  699. # Compute the scaled dot-product attention scores
  700. attn_weights = torch.matmul(query_layer, key_layer.transpose(-1, -2)) # [bs, numheads, querylength, keylength]
  701. attn_weights = attn_weights - attn_weights.amax(dim=-1, keepdim=True).detach() # To stabilize score
  702. attention_scores = attn_weights.masked_fill(
  703. (1 - attention_mask).bool(), torch.finfo(attn_weights.dtype).min
  704. ) # [bs, numheads, querylength, keylength]
  705. attention_probs = nn.Softmax(dim=-1)(attention_scores)
  706. # attention_probs_dropped = self.dropout(attention_probs)
  707. context_layer = torch.matmul(attention_probs, value_layer) # [bs, numheads, querylength, dim/numheads]
  708. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  709. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  710. context_layer = context_layer.view(*new_context_layer_shape)
  711. context_layer = self.out_proj(context_layer)
  712. return context_layer
  713. def forward(
  714. self,
  715. query_states,
  716. protein_kv_states,
  717. structure_kv_states,
  718. msa_kv_states,
  719. query_attn_mask,
  720. protein_kv_attn_mask=None,
  721. structure_kv_attn_mask=None,
  722. msa_kv_attn_mask=None,
  723. protein_batch_mask=None,
  724. structure_batch_mask=None,
  725. msa_batch_mask=None,
  726. past_key_values=None,
  727. ):
  728. if protein_kv_states is not None:
  729. bs, protein_kv_seq_len, dim = protein_kv_states.shape
  730. if protein_kv_attn_mask is None:
  731. protein_kv_attn_mask = (
  732. torch.ones(bs, protein_kv_seq_len).to(protein_batch_mask.device)
  733. * protein_batch_mask.expand(size=(protein_kv_seq_len, bs)).T
  734. ).to(protein_kv_states.device)
  735. else:
  736. protein_kv_attn_mask = None
  737. if structure_kv_states is not None:
  738. bs, structure_kv_seq_len, dim = structure_kv_states.shape
  739. if structure_kv_attn_mask is None:
  740. structure_kv_attn_mask = (
  741. torch.ones(bs, structure_kv_seq_len).to(protein_batch_mask.device)
  742. * structure_batch_mask.expand(size=(structure_kv_seq_len, bs)).T
  743. ).to(structure_kv_states.device)
  744. else:
  745. structure_kv_attn_mask = None
  746. if msa_kv_states is not None:
  747. bs, msa_kv_seq_len, dim = msa_kv_states.shape
  748. if msa_kv_attn_mask is None:
  749. msa_kv_attn_mask = (
  750. torch.ones(bs, msa_kv_seq_len).to(protein_batch_mask.device)
  751. * msa_batch_mask.expand(size=(msa_kv_seq_len, bs)).T
  752. ).to(msa_kv_states.device)
  753. else:
  754. msa_kv_attn_mask = None
  755. hidden_states = query_states
  756. # only when there's at least one valid modality, crossattention will be performed
  757. if (
  758. (protein_kv_states is not None and protein_kv_attn_mask.any())
  759. or (structure_kv_states is not None and structure_kv_attn_mask.any())
  760. or (msa_kv_states is not None and msa_kv_attn_mask.any())
  761. ):
  762. residual = hidden_states
  763. hidden_states = self.cross_attention(
  764. query_states=hidden_states,
  765. protein_key_value_states=protein_kv_states,
  766. structure_key_value_states=structure_kv_states,
  767. msa_key_value_states=msa_kv_states,
  768. query_attn_mask=query_attn_mask,
  769. protein_kv_attn_mask=protein_kv_attn_mask,
  770. structure_kv_attn_mask=structure_kv_attn_mask,
  771. msa_kv_attn_mask=msa_kv_attn_mask,
  772. ) # [bs, query_seq_len, dim]
  773. # tanh gate
  774. hidden_states = torch.tanh(self.gate_attention) * hidden_states
  775. hidden_states = residual + hidden_states # input_query
  776. residual = hidden_states
  777. hidden_states = self.ff(hidden_states) * torch.tanh(self.gate_ffw)
  778. hidden_states = residual + hidden_states
  779. return hidden_states
  780. @use_kernel_forward_from_hub("RMSNorm")
  781. class EvollaRMSNorm(nn.Module):
  782. def __init__(self, hidden_size, eps: float = 1e-6) -> None:
  783. """
  784. EvollaRMSNorm is equivalent to T5LayerNorm
  785. """
  786. super().__init__()
  787. self.weight = nn.Parameter(torch.ones(hidden_size))
  788. self.variance_epsilon = eps
  789. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  790. input_dtype = hidden_states.dtype
  791. hidden_states = hidden_states.to(torch.float32)
  792. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  793. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  794. return self.weight * hidden_states.to(input_dtype)
  795. def extra_repr(self):
  796. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
class EvollaRotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) producing the `cos`/`sin` tables passed to the decoder attention.

    Inverse frequencies come from `compute_default_rope_parameters` when `rope_type == "default"` and from
    `ROPE_INIT_FUNCTIONS[rope_type]` otherwise.
    """

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: EvollaConfig, device=None):
        super().__init__()
        # Sequence-length bookkeeping; presumably read by `dynamic_rope_update` for dynamic RoPE types -- confirm.
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config

        self.rope_type = self.config.rope_parameters["rope_type"]
        rope_init_fn: Callable = self.compute_default_rope_parameters
        if self.rope_type != "default":
            rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        # `attention_scaling` multiplies the cos/sin tables in `forward`.
        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

        # Non-persistent buffers follow the module across devices but are not saved in the state dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Untouched copy of the initial frequencies (presumably so `dynamic_rope_update` can restore them -- confirm).
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

    @staticmethod
    def compute_default_rope_parameters(
        config: EvollaConfig | None = None,
        device: Optional["torch.device"] = None,
        seq_len: int | None = None,
    ) -> tuple["torch.Tensor", float]:
        """
        Computes the inverse frequencies according to the original RoPE implementation
        Args:
            config ([`~transformers.PreTrainedConfig`]):
                The model configuration.
            device (`torch.device`):
                The device to use for initialization of the inverse frequencies.
            seq_len (`int`, *optional*):
                The current sequence length. Unused for this type of RoPE.
        Returns:
            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
            post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
        """
        base = config.rope_parameters["rope_theta"]
        # Rotary dimension: `head_dim` when set and truthy, otherwise hidden_size // num_attention_heads.
        dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads

        attention_factor = 1.0  # Unused in this type of RoPE

        # Compute the inverse frequencies: 1 / base ** (j / dim) for even j in [0, dim)
        inv_freq = 1.0 / (
            base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
        )
        return inv_freq, attention_factor

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        # `x` is only used for its device and dtype; `position_ids` is indexed as 2-D [batch, seq_len].
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        # MPS (or a non-string device type) falls back to "cpu" as the autocast device type.
        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
            # Frequencies x positions -> rotation angles, transposed to [batch, seq_len, n_freq].
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            # Duplicate the angles so they cover both halves of the rotary dimension.
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  850. class EvollaMLP(nn.Module):
  851. def __init__(self, config):
  852. super().__init__()
  853. self.config = config
  854. self.hidden_size = config.hidden_size
  855. self.intermediate_size = config.intermediate_size
  856. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
  857. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
  858. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
  859. self.act_fn = ACT2FN[config.hidden_act]
  860. def forward(self, x):
  861. down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
  862. return down_proj
  863. def rotate_half(x):
  864. """Rotates half the hidden dims of the input."""
  865. x1 = x[..., : x.shape[-1] // 2]
  866. x2 = x[..., x.shape[-1] // 2 :]
  867. return torch.cat((-x2, x1), dim=-1)
  868. @use_kernel_func_from_hub("rotary_pos_emb")
  869. def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
  870. """Applies Rotary Position Embedding to the query and key tensors.
  871. Args:
  872. q (`torch.Tensor`): The query tensor.
  873. k (`torch.Tensor`): The key tensor.
  874. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  875. sin (`torch.Tensor`): The sine part of the rotary embedding.
  876. unsqueeze_dim (`int`, *optional*, defaults to 1):
  877. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  878. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  879. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  880. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  881. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  882. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  883. Returns:
  884. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  885. """
  886. cos = cos.unsqueeze(unsqueeze_dim)
  887. sin = sin.unsqueeze(unsqueeze_dim)
  888. q_embed = (q * cos) + (rotate_half(q) * sin)
  889. k_embed = (k * cos) + (rotate_half(k) * sin)
  890. return q_embed, k_embed
  891. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  892. """
  893. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  894. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  895. """
  896. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  897. if n_rep == 1:
  898. return hidden_states
  899. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  900. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  901. @use_kernelized_func(apply_rotary_pos_emb)
  902. class EvollaAttention(nn.Module):
  903. """Multi-headed attention from 'Attention Is All You Need' paper"""
  904. def __init__(self, config: EvollaConfig, layer_idx: int):
  905. super().__init__()
  906. self.config = config
  907. self.layer_idx = layer_idx
  908. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  909. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  910. self.scaling = self.head_dim**-0.5
  911. self.attention_dropout = config.attention_dropout
  912. self.is_causal = True
  913. self.q_proj = nn.Linear(
  914. config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
  915. )
  916. self.k_proj = nn.Linear(
  917. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  918. )
  919. self.v_proj = nn.Linear(
  920. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  921. )
  922. self.o_proj = nn.Linear(
  923. config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
  924. )
  925. def forward(
  926. self,
  927. hidden_states: torch.Tensor,
  928. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  929. attention_mask: torch.Tensor | None = None,
  930. past_key_values: Cache | None = None,
  931. **kwargs: Unpack[TransformersKwargs],
  932. ) -> tuple[torch.Tensor, torch.Tensor]:
  933. input_shape = hidden_states.shape[:-1]
  934. hidden_shape = (*input_shape, -1, self.head_dim)
  935. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  936. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  937. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  938. cos, sin = position_embeddings
  939. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  940. if past_key_values is not None:
  941. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  942. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  943. self.config._attn_implementation, eager_attention_forward
  944. )
  945. attn_output, attn_weights = attention_interface(
  946. self,
  947. query_states,
  948. key_states,
  949. value_states,
  950. attention_mask,
  951. dropout=0.0 if not self.training else self.attention_dropout,
  952. scaling=self.scaling,
  953. **kwargs,
  954. )
  955. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  956. attn_output = self.o_proj(attn_output)
  957. return attn_output, attn_weights
  958. class EvollaDecoderLayer(GradientCheckpointingLayer):
  959. def __init__(self, config: EvollaConfig, layer_idx: int):
  960. super().__init__()
  961. self.hidden_size = config.hidden_size
  962. self.self_attn = EvollaAttention(config=config, layer_idx=layer_idx)
  963. self.mlp = EvollaMLP(config)
  964. self.input_layernorm = EvollaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  965. self.post_attention_layernorm = EvollaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  966. if (layer_idx + 1) % max(config.num_hidden_layers // config.aligner_num_add_layers, 1) == 0:
  967. self.adapter = EvollaSequenceAlignerCrossAttention(
  968. config,
  969. protein_encoder_dim=config.hidden_size,
  970. )
  971. def forward(
  972. self,
  973. hidden_states: torch.Tensor,
  974. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  975. attention_mask: torch.Tensor | None = None,
  976. position_ids: torch.LongTensor | None = None,
  977. past_key_values: Cache | None = None,
  978. use_cache: bool | None = False,
  979. protein_kv_states: torch.Tensor | None = None,
  980. structure_kv_states: torch.Tensor | None = None,
  981. msa_kv_states: torch.Tensor | None = None,
  982. protein_batch_mask: torch.Tensor | None = None,
  983. structure_batch_mask: torch.Tensor | None = None,
  984. msa_batch_mask: torch.Tensor | None = None,
  985. query_attn_mask: torch.Tensor | None = None,
  986. **kwargs,
  987. ) -> torch.Tensor:
  988. residual = hidden_states
  989. hidden_states = self.input_layernorm(hidden_states)
  990. # Self Attention
  991. hidden_states, _ = self.self_attn(
  992. hidden_states=hidden_states,
  993. attention_mask=attention_mask,
  994. position_ids=position_ids,
  995. past_key_values=past_key_values,
  996. use_cache=use_cache,
  997. position_embeddings=position_embeddings,
  998. **kwargs,
  999. )
  1000. hidden_states = residual + hidden_states
  1001. # Fully Connected
  1002. residual = hidden_states
  1003. hidden_states = self.post_attention_layernorm(hidden_states)
  1004. hidden_states = self.mlp(hidden_states)
  1005. hidden_states = residual + hidden_states
  1006. if hasattr(self, "adapter"):
  1007. hidden_states = self.adapter(
  1008. query_states=hidden_states,
  1009. protein_kv_states=protein_kv_states,
  1010. structure_kv_states=structure_kv_states,
  1011. msa_kv_states=msa_kv_states,
  1012. query_attn_mask=query_attn_mask,
  1013. protein_batch_mask=protein_batch_mask,
  1014. structure_batch_mask=structure_batch_mask,
  1015. msa_batch_mask=msa_batch_mask,
  1016. )
  1017. return hidden_states
@auto_docstring
class EvollaPreTrainedModel(PreTrainedModel):
    # Shared base for Evolla models: framework capability flags plus Evolla-specific weight initialization.
    config: EvollaConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # Module classes declared under the Transformers `_no_split_modules` convention (device-map placement).
    _no_split_modules = [
        "EvollaDecoderLayer",
        "EvollaSequenceCompressorResampler",
        "EvollaSequenceAlignerCrossAttention",
    ]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn = False  # see dependency on `EvollaSequenceCompressorResampler`
    _supports_sdpa = True
    _supports_flex_attn = False  # see dependency on `EvollaSequenceCompressorResampler`

    _can_compile_fullgraph = True
    _supports_attention_backend = False
    # Which module outputs are collected as `hidden_states` / `attentions` when outputs are recorded.
    _can_record_outputs = {
        "hidden_states": EvollaDecoderLayer,
        "attentions": EvollaAttention,
    }

    @torch.no_grad()
    def _init_weights(self, module):
        """Run the generic initialization, then apply Evolla-specific init to adapters and resamplers."""
        std = self.config.initializer_range
        super()._init_weights(module)
        if isinstance(module, EvollaSequenceAlignerCrossAttention):
            # Zero gates => tanh(gate) == 0, so a fresh adapter adds nothing to the residual stream.
            init.zeros_(module.gate_attention)
            init.zeros_(module.gate_ffw)
            init.ones_(module.attention_norm.weight)
        elif isinstance(module, EvollaSequenceCompressorResampler):
            # The learned latents are a bare nn.Parameter, so initialize them explicitly.
            init.normal_(module.latents, mean=0.0, std=std)
class EvollaModel(EvollaPreTrainedModel):
    """Evolla text decoder that can condition on a protein.

    The protein is encoded and compressed by `EvollaProteinEncoder`. The compressed features are passed to
    every decoder layer, and layers that carry a cross-attention adapter attend to them.
    """

    def __init__(self, config: EvollaConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = nn.Embedding(self.vocab_size, config.hidden_size, self.padding_idx)
        self.protein_encoder = EvollaProteinEncoder(config=config)
        self.layers = nn.ModuleList(
            [
                EvollaDecoderLayer(
                    config=config,
                    layer_idx=layer_idx,
                )
                for layer_idx in range(config.num_hidden_layers)
            ]
        )

        self.norm = EvollaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = getattr(config, "gradient_checkpointing", False)
        self.rotary_emb = EvollaRotaryEmbedding(config=config)
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    @auto_docstring
    @merge_with_config_defaults
    @capture_outputs
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        past_key_values: Cache | None = None,
        inputs_embeds: torch.FloatTensor | None = None,
        use_cache: bool | None = None,
        protein_input_ids: torch.LongTensor | None = None,
        protein_attention_mask: torch.Tensor | None = None,
        structure_feats: torch.FloatTensor | None = None,
        msa_feats: torch.FloatTensor | None = None,
        structure_batch_mask: torch.Tensor | None = None,
        msa_batch_mask: torch.Tensor | None = None,
        **kwargs,
    ) -> tuple | BaseModelOutputWithPast:
        r"""
        protein_input_ids (torch.LongTensor):
            The input IDs for the protein sequence in structure-aware tokens. Should be of shape `(batch_size, protein_seq_length)` and type `torch.LongTensor`. Only used when `protein_attention_mask` is also given.
        protein_attention_mask (torch.Tensor):
            The attention mask for the protein sequence. Should be of shape `(batch_size, protein_seq_length)` and type `torch.Tensor`. Only used when `protein_input_ids` is also given.
        structure_feats (torch.FloatTensor):
            Purely structure-based features. Should be of shape `(batch_size, structure_seq_length, structure_feat_dim)` and type `torch.FloatTensor`. Dummy input for now.
        msa_feats (torch.FloatTensor):
            Purely MSA-based features. Should be of shape `(batch_size, msa_seq_length, msa_feat_dim)` and type `torch.FloatTensor`. Dummy input for now.
        structure_batch_mask (torch.Tensor):
            The batch mask to decide which protein sequences are purely structure-based. Should be of shape `(batch_size)` and type `torch.Tensor`. Should be paired with `structure_feats`. Dummy input for now.
        msa_batch_mask (torch.Tensor):
            The batch mask to decide which protein sequences are purely MSA-based. Should be of shape `(batch_size)` and type `torch.Tensor`. Should be paired with `msa_feats`. Dummy input for now.
        """
        # Exactly one of `input_ids` / `inputs_embeds` must be given.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        # Default positions continue after the tokens already stored in the cache (if any).
        if position_ids is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
            position_ids = position_ids.unsqueeze(0)

        protein_feats = None
        protein_batch_mask = None
        # If provided, actually compute them
        # NOTE: the protein branch needs BOTH `protein_input_ids` and `protein_attention_mask`; passing only the
        # ids skips it without any warning.
        if protein_input_ids is not None and protein_attention_mask is not None:
            protein_outputs = self.protein_encoder(
                input_ids=protein_input_ids,
                attention_mask=protein_attention_mask,
            )
            protein_feats = protein_outputs.sequence_compressor_output
            # Every sample in this batch has a protein, so the per-sample mask is all True.
            protein_batch_mask = torch.ones(
                protein_input_ids.shape[0],
                device=protein_input_ids.device,
                dtype=torch.bool,
            )

        causal_mask = create_causal_mask(
            config=self.config,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
        )

        hidden_states = inputs_embeds

        # RoPE cos/sin tables are computed once here and shared by all decoder layers.
        position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)

        for decoder_layer in self.layers:
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                protein_kv_states=protein_feats,
                structure_kv_states=structure_feats,
                msa_kv_states=msa_feats,
                protein_batch_mask=protein_batch_mask,
                structure_batch_mask=structure_batch_mask,
                msa_batch_mask=msa_batch_mask,
                # The raw text attention mask (not the causal mask) is what the cross-attention adapters use to
                # mask query positions.
                query_attn_mask=attention_mask,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)

        output = BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
        )
        return output
class EvollaForProteinText2Text(EvollaPreTrainedModel, GenerationMixin):
    """`EvollaModel` with a language-modeling head: generates text, optionally conditioned on a protein."""

    def __init__(self, config):
        super().__init__(config)
        self.model = EvollaModel(config)
        self.vocab_size = config.vocab_size

        self.lm_head = nn.Linear(config.hidden_size, self.vocab_size, bias=False)

        self.post_init()

    def get_input_embeddings(self):
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, value):
        return self.model.set_input_embeddings(value)

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: torch.LongTensor | None = None,  # text input ids
        attention_mask: torch.Tensor | None = None,  # text attention mask
        inputs_embeds: torch.FloatTensor | None = None,  # text input embeddings
        labels: torch.LongTensor | None = None,
        protein_input_ids: torch.LongTensor | None = None,
        protein_attention_mask: torch.Tensor | None = None,
        use_cache: bool | None = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs,
    ):
        r"""
        protein_input_ids (torch.LongTensor):
            The input IDs for the protein sequence. Should be of shape `(batch_size, protein_seq_length)` and type `torch.LongTensor`.
        protein_attention_mask (torch.Tensor):
            The attention mask for the protein sequence. Should be of shape `(batch_size, protein_seq_length)` and type `torch.Tensor`.
        Example:

        ```python
        >>> from transformers import EvollaProcessor, EvollaForProteinText2Text
        >>> model = EvollaForProteinText2Text.from_pretrained("westlake/Evolla-10B-hf")
        >>> processor = EvollaProcessor.from_pretrained("westlake/Evolla-10B-hf")

        >>> protein_information = {
        ...     "aa_seq": "your amino acid sequence",
        ...     "foldseek": "your foldseek sequence",
        ... }
        >>> question = "What is the function of this protein?"
        >>> message = [
        ...     {"role": "system", "content": "You are an AI expert that can answer any questions about protein."},
        ...     {"role": "user", "content": question},
        ... ]

        >>> inputs = processor(proteins=[protein_information], messages_list=[message], return_tensors="pt", padding="longest")
        >>> outputs = model.generate(**inputs)

        >>> print(processor.batch_decode(outputs, skip_special_tokens=True))
        ```"""
        # Remaining arguments (past_key_values, position_ids, structure/MSA inputs, ...) reach the model via **kwargs.
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            protein_input_ids=protein_input_ids,
            protein_attention_mask=protein_attention_mask,
            use_cache=use_cache,
            **kwargs,
        )

        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        # An int keeps the last `logits_to_keep` positions (0 -> slice(0, None) keeps all); a tensor is used as indices.
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.vocab_size, **kwargs)

        lm_outputs = CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
        return lm_outputs
# Public classes exported by this module.
__all__ = ["EvollaForProteinText2Text", "EvollaModel", "EvollaPreTrainedModel"]