| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275 |
- # Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """ESM model configuration"""
- from typing import Union
- from huggingface_hub.dataclasses import strict
- from ...configuration_utils import PreTrainedConfig
- from ...utils import auto_docstring, logging
- from ...utils.type_validators import interval, is_divisible_by
# Module-level logger; `EsmConfig.__post_init__` uses it to report when folding defaults are substituted.
logger = logging.get_logger(__name__)
@strict
class StructureModuleConfig(PreTrainedConfig):
    """
    Hyperparameters of the ESMFold structure module, nested inside `TrunkConfig.structure_module`.

    Args:
        sequence_dim (`int`, *optional*, defaults to 384):
            Single representation channel dimension
        pairwise_dim (`int`, *optional*, defaults to 128):
            Pair representation channel dimension
        ipa_dim (`int`, *optional*, defaults to 16):
            IPA hidden channel dimension
        resnet_dim (`int`, *optional*, defaults to 128):
            Angle resnet (Alg. 23 lines 11-14) hidden channel dimension
        num_heads_ipa (`int`, *optional*, defaults to 12):
            Number of IPA heads
        num_qk_points (`int`, *optional*, defaults to 4):
            Number of query/key points to generate during IPA
        num_v_points (`int`, *optional*, defaults to 8):
            Number of value points to generate during IPA
        dropout_rate (`float`, *optional*, defaults to 0.1):
            Dropout rate used throughout the layer
        num_blocks (`int`, *optional*, defaults to 8):
            Number of structure module blocks
        num_transition_layers (`int`, *optional*, defaults to 1):
            Number of layers in the single representation transition (Alg. 23 lines 8-9)
        num_resnet_blocks (`int`, *optional*, defaults to 2):
            Number of blocks in the angle resnet
        num_angles (`int`, *optional*, defaults to 7):
            Number of angles to generate in the angle resnet
        trans_scale_factor (`int`, *optional*, defaults to 10):
            Scale of single representation transition hidden dimension
        epsilon (`float`, *optional*, defaults to 1e-8):
            Small number used in angle resnet normalization
        inf (`float`, *optional*, defaults to 1e5):
            Large number used for attention masking
    """

    # Field order is significant: it is the positional-argument order of the generated constructor.
    sequence_dim: int | None = 384
    pairwise_dim: int | None = 128
    ipa_dim: int | None = 16
    resnet_dim: int | None = 128
    num_heads_ipa: int | None = 12
    num_qk_points: int | None = 4
    num_v_points: int | None = 8
    dropout_rate: float | None = 0.1
    num_blocks: int | None = 8
    num_transition_layers: int | None = 1
    num_resnet_blocks: int | None = 2
    num_angles: int | None = 7
    trans_scale_factor: int | None = 10
    epsilon: float | None = 1e-8
    inf: float | None = 1e5
@strict
class TrunkConfig(PreTrainedConfig):
    """
    Configuration of the ESMFold folding trunk, nested inside `EsmFoldConfig.trunk`.

    Besides the per-field validators declared below (`pairwise_state_dim` must be divisible by 2, `dropout` is
    bounded above by 0.4 via `interval(max=0.4)`, `max_recycles` is bounded below by 0 via `interval(min=0)`),
    `validate_architecture` checks that each state dimension splits into a whole number of attention heads of
    the configured (positive) width.

    Args:
        num_blocks: Number of trunk blocks.
        sequence_state_dim: Channel dimension of the sequence representation. Must be a multiple of
            `sequence_head_width`.
        pairwise_state_dim: Channel dimension of the pairwise representation. Must be even and a multiple of
            `pairwise_head_width`.
        sequence_head_width: Attention head width for the sequence representation. Must be positive.
        pairwise_head_width: Attention head width for the pairwise representation. Must be positive.
        position_bins: Number of position bins.
        dropout: Trunk dropout rate.
        layer_drop: Layer-drop rate.
        cpu_grad_checkpoint: CPU gradient-checkpointing flag.
        max_recycles: Maximum number of recycling iterations.
        chunk_size: Chunk size used by the trunk.
        structure_module: Structure module configuration, given as a `StructureModuleConfig` or a plain `dict`
            of its fields. `None` builds a default `StructureModuleConfig`.
    """

    sub_configs = {"structure_module": StructureModuleConfig}
    num_blocks: int | None = 48
    sequence_state_dim: int | None = 1024
    pairwise_state_dim: int | None = is_divisible_by(divisor=2)(default=128)
    sequence_head_width: int | None = 32
    pairwise_head_width: int | None = 32
    position_bins: int | None = 32
    dropout: float | int | None = interval(max=0.4)(default=0.0)
    layer_drop: float | int | None = 0.0
    cpu_grad_checkpoint: bool | None = False
    max_recycles: int | None = interval(min=0)(default=4)
    chunk_size: int | None = 128
    structure_module: Union[dict, "StructureModuleConfig"] | None = None

    def __post_init__(self, **kwargs):
        # Normalize `structure_module` to a StructureModuleConfig instance before the base class runs.
        if self.structure_module is None:
            self.structure_module = StructureModuleConfig()
        elif isinstance(self.structure_module, dict):
            self.structure_module = StructureModuleConfig(**self.structure_module)
        super().__post_init__(**kwargs)

    def validate_architecture(self):
        """
        Check that the sequence and pairwise state dimensions split evenly into attention heads.

        Raises:
            ValueError: if a head width is missing or not positive, or if a state dimension is not an exact
                multiple of its head width.
        """
        # Head widths are used as divisors below. Reject zero (which would raise an opaque ZeroDivisionError)
        # and negatives (a negative width yields a negative head count that still satisfies the product check).
        for width_name in ("sequence_head_width", "pairwise_head_width"):
            width = getattr(self, width_name)
            if width is None or width <= 0:
                raise ValueError(f"`{width_name}` should be a positive integer, got {width}.")

        sequence_num_heads = self.sequence_state_dim // self.sequence_head_width
        pairwise_num_heads = self.pairwise_state_dim // self.pairwise_head_width
        # Floor division followed by multiplication reproduces the dim only when it is an exact multiple.
        if self.sequence_state_dim != sequence_num_heads * self.sequence_head_width:
            raise ValueError(
                "`sequence_state_dim` should be equal to `sequence_num_heads * sequence_head_width`, got"
                f" {self.sequence_state_dim} != {sequence_num_heads} * {self.sequence_head_width}."
            )
        if self.pairwise_state_dim != pairwise_num_heads * self.pairwise_head_width:
            raise ValueError(
                "`pairwise_state_dim` should be equal to `pairwise_num_heads * pairwise_head_width`, got"
                f" {self.pairwise_state_dim} != {pairwise_num_heads} * {self.pairwise_head_width}."
            )
@strict
class EsmFoldConfig(PreTrainedConfig):
    """
    Configuration of the ESMFold folding head, nested inside `EsmConfig.esmfold_config`.

    Apart from `trunk`, the fields are plain flags and sizes read by the ESMFold modeling code, which defines
    their exact effect.

    Args:
        esm_type: Optional identifier of the ESM backbone variant (defaults to `None`).
        fp16_esm: Half-precision flag for the ESM backbone (defaults to `True`).
        use_esm_attn_map: Must remain `False`; `EsmConfig` rejects a folding config with it enabled.
        esm_ablate_pairwise: Pairwise ablation flag (defaults to `False`).
        esm_ablate_sequence: Sequence ablation flag (defaults to `False`).
        esm_input_dropout: ESM input dropout rate (defaults to `0.0`).
        embed_aa: Amino-acid embedding flag (defaults to `True`).
        bypass_lm: Language-model bypass flag (defaults to `False`).
        lddt_head_hid_dim: Hidden dimension of the lDDT head (defaults to 128).
        trunk: Trunk configuration, as a `TrunkConfig` or a plain `dict` of its fields. `None` builds a default
            `TrunkConfig`.
    """

    sub_configs = {"trunk": TrunkConfig}
    esm_type: str | None = None
    fp16_esm: bool | None = True
    use_esm_attn_map: bool | None = False
    esm_ablate_pairwise: bool | None = False
    esm_ablate_sequence: bool | None = False
    esm_input_dropout: float | int | None = 0.0
    embed_aa: bool | None = True
    bypass_lm: bool | None = False
    lddt_head_hid_dim: int | None = 128
    trunk: Union[dict, "TrunkConfig"] | None = None

    def __post_init__(self, **kwargs):
        # Turn `trunk` into a TrunkConfig: expand a dict into keyword arguments, or use defaults when absent.
        # An existing TrunkConfig instance is left untouched.
        if isinstance(self.trunk, dict):
            self.trunk = TrunkConfig(**self.trunk)
        elif self.trunk is None:
            self.trunk = TrunkConfig()
        super().__post_init__(**kwargs)
@auto_docstring(checkpoint="facebook/esm-1b")
@strict
class EsmConfig(PreTrainedConfig):
    r"""
    mask_token_id (`int`, *optional*):
        The index of the mask token in the vocabulary. This must be included in the config because of the
        "mask-dropout" scaling trick, which will scale the inputs depending on the number of masked tokens.
    position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
        Type of position embedding. Choose either `"absolute"` or `"rotary"`.
    emb_layer_norm_before (`bool`, *optional*):
        Whether to apply layer normalization after embeddings but before the main stem of the network.
    token_dropout (`bool`, *optional*, defaults to `False`):
        When this is enabled, masked tokens are treated as if they had been dropped out by input dropout.
    is_folding_model (`bool`, *optional*, defaults to `False`):
        When this is enabled, ESMFold model will be initialized. When it is disabled, `esmfold_config` and
        `vocab_list` are reset to `None` regardless of the values passed in.
    esmfold_config (`dict` or `EsmFoldConfig`, *optional*):
        Configuration to initiate the ESMFold module. Only used when `is_folding_model=True`: a `dict` is
        converted to an `EsmFoldConfig`, and `None` falls back to a default `EsmFoldConfig`. Enabling
        `use_esm_attn_map` in it is not supported and raises a `ValueError`.
    vocab_list (`list[str]` or `tuple[str, ...]`, *optional*):
        List of the vocabulary items. Only used when `is_folding_model=True`; `None` falls back to the default
        ESM-2 vocabulary.

    Examples:

    ```python
    >>> from transformers import EsmModel, EsmConfig

    >>> # Initializing an ESM facebook/esm-1b style configuration
    >>> configuration = EsmConfig(vocab_size=33)

    >>> # Initializing a model from the configuration
    >>> model = EsmModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "esm"
    sub_configs = {"esmfold_config": EsmFoldConfig}
    vocab_size: int | None = None
    mask_token_id: int | None = None
    pad_token_id: int | None = None
    hidden_size: int = 768
    num_hidden_layers: int = 12
    num_attention_heads: int = 12
    intermediate_size: int = 3072
    hidden_dropout_prob: float | None = 0.1
    attention_probs_dropout_prob: float | None = 0.1
    max_position_embeddings: int = 1026
    initializer_range: float = 0.02
    layer_norm_eps: float | None = 1e-12
    position_embedding_type: str | None = "absolute"
    use_cache: bool = True
    emb_layer_norm_before: bool | None = None
    token_dropout: bool | None = False
    is_folding_model: bool | None = False
    esmfold_config: dict | EsmFoldConfig | None = None
    vocab_list: list[str] | tuple[str, ...] | None = None
    is_decoder: bool | None = False
    add_cross_attention: bool | None = False
    tie_word_embeddings: bool = True

    def __post_init__(self, **kwargs):
        if self.is_folding_model:
            # A folding model needs a concrete EsmFoldConfig: accept a plain dict, or fall back to defaults.
            if self.esmfold_config is None:
                logger.info("No esmfold_config supplied for folding model, using default values.")
                self.esmfold_config = EsmFoldConfig()
            elif isinstance(self.esmfold_config, dict):
                self.esmfold_config = EsmFoldConfig(**self.esmfold_config)
            if self.vocab_list is None:
                logger.warning("No vocab_list supplied for folding model, assuming the ESM-2 vocabulary!")
                self.vocab_list = get_default_vocab_list()
        else:
            # Plain language models carry no folding settings; any supplied values are intentionally dropped.
            self.esmfold_config = None
            self.vocab_list = None

        # Only reachable for folding models, since `esmfold_config` was cleared above otherwise.
        if self.esmfold_config is not None and getattr(self.esmfold_config, "use_esm_attn_map", False):
            raise ValueError("The HuggingFace port of ESMFold does not support use_esm_attn_map at this time!")
        super().__post_init__(**kwargs)
def get_default_vocab_list():
    """
    Return the vocabulary assumed for folding models when no `vocab_list` is supplied (the ESM-2 vocabulary).

    The result is a 33-element tuple: four leading special tokens, 27 single-character tokens, and two
    trailing special tokens, in that fixed order.
    """
    leading_specials = ("<cls>", "<pad>", "<eos>", "<unk>")
    # Every character of this string becomes its own token, preserving left-to-right order.
    single_char_tokens = tuple("LAGVSERTIDPKQNFYMHWCXBUZO.-")
    trailing_specials = ("<null_1>", "<mask>")
    return leading_specials + single_char_tokens + trailing_specials
# Only `EsmConfig` is exported; the ESMFold sub-configurations are reached through `EsmConfig.esmfold_config`.
__all__ = ["EsmConfig"]
|