configuration_esm.py 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275
  1. # Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """ESM model configuration"""
  15. from typing import Union
  16. from huggingface_hub.dataclasses import strict
  17. from ...configuration_utils import PreTrainedConfig
  18. from ...utils import auto_docstring, logging
  19. from ...utils.type_validators import interval, is_divisible_by
# Module-level logger; EsmConfig.__post_init__ uses it to report when folding-model defaults are filled in.
logger = logging.get_logger(__name__)
  21. @strict
  22. class StructureModuleConfig(PreTrainedConfig):
  23. """
  24. Args:
  25. sequence_dim:
  26. Single representation channel dimension
  27. pairwise_dim:
  28. Pair representation channel dimension
  29. ipa_dim:
  30. IPA hidden channel dimension
  31. resnet_dim:
  32. Angle resnet (Alg. 23 lines 11-14) hidden channel dimension
  33. num_heads_ipa:
  34. Number of IPA heads
  35. num_qk_points:
  36. Number of query/key points to generate during IPA
  37. num_v_points:
  38. Number of value points to generate during IPA
  39. dropout_rate:
  40. Dropout rate used throughout the layer
  41. num_blocks:
  42. Number of structure module blocks
  43. num_transition_layers:
  44. Number of layers in the single representation transition (Alg. 23 lines 8-9)
  45. num_resnet_blocks:
  46. Number of blocks in the angle resnet
  47. num_angles:
  48. Number of angles to generate in the angle resnet
  49. trans_scale_factor:
  50. Scale of single representation transition hidden dimension
  51. epsilon:
  52. Small number used in angle resnet normalization
  53. inf:
  54. Large number used for attention masking
  55. """
  56. sequence_dim: int | None = 384
  57. pairwise_dim: int | None = 128
  58. ipa_dim: int | None = 16
  59. resnet_dim: int | None = 128
  60. num_heads_ipa: int | None = 12
  61. num_qk_points: int | None = 4
  62. num_v_points: int | None = 8
  63. dropout_rate: float | None = 0.1
  64. num_blocks: int | None = 8
  65. num_transition_layers: int | None = 1
  66. num_resnet_blocks: int | None = 2
  67. num_angles: int | None = 7
  68. trans_scale_factor: int | None = 10
  69. epsilon: float | None = 1e-8
  70. inf: float | None = 1e5
  71. @strict
  72. class TrunkConfig(PreTrainedConfig):
  73. sub_configs = {"structure_module": StructureModuleConfig}
  74. num_blocks: int | None = 48
  75. sequence_state_dim: int | None = 1024
  76. pairwise_state_dim: int | None = is_divisible_by(divisor=2)(default=128)
  77. sequence_head_width: int | None = 32
  78. pairwise_head_width: int | None = 32
  79. position_bins: int | None = 32
  80. dropout: float | int | None = interval(max=0.4)(default=0.0)
  81. layer_drop: float | int | None = 0.0
  82. cpu_grad_checkpoint: bool | None = False
  83. max_recycles: int | None = interval(min=0)(default=4)
  84. chunk_size: int | None = 128
  85. structure_module: Union[dict, "StructureModuleConfig"] | None = None
  86. def __post_init__(self, **kwargs):
  87. if self.structure_module is None:
  88. self.structure_module = StructureModuleConfig()
  89. elif isinstance(self.structure_module, dict):
  90. self.structure_module = StructureModuleConfig(**self.structure_module)
  91. super().__post_init__(**kwargs)
  92. def validate_architecture(self):
  93. if self.sequence_state_dim % self.sequence_state_dim != 0:
  94. raise ValueError(
  95. "`sequence_state_dim` should be a round multiple of `sequence_state_dim`, got"
  96. f" {self.sequence_state_dim} and {self.sequence_state_dim}."
  97. )
  98. if self.pairwise_state_dim % self.pairwise_state_dim != 0:
  99. raise ValueError(
  100. "`pairwise_state_dim` should be a round multiple of `pairwise_state_dim`, got"
  101. f" {self.pairwise_state_dim} and {self.pairwise_state_dim}."
  102. )
  103. sequence_num_heads = self.sequence_state_dim // self.sequence_head_width
  104. pairwise_num_heads = self.pairwise_state_dim // self.pairwise_head_width
  105. if self.sequence_state_dim != sequence_num_heads * self.sequence_head_width:
  106. raise ValueError(
  107. "`sequence_state_dim` should be equal to `sequence_num_heads * sequence_head_width, got"
  108. f" {self.sequence_state_dim} != {sequence_num_heads} * {self.sequence_head_width}."
  109. )
  110. if self.pairwise_state_dim != pairwise_num_heads * self.pairwise_head_width:
  111. raise ValueError(
  112. "`pairwise_state_dim` should be equal to `pairwise_num_heads * pairwise_head_width, got"
  113. f" {self.pairwise_state_dim} != {pairwise_num_heads} * {self.pairwise_head_width}."
  114. )
  115. @strict
  116. class EsmFoldConfig(PreTrainedConfig):
  117. sub_configs = {"trunk": TrunkConfig}
  118. esm_type: str | None = None
  119. fp16_esm: bool | None = True
  120. use_esm_attn_map: bool | None = False
  121. esm_ablate_pairwise: bool | None = False
  122. esm_ablate_sequence: bool | None = False
  123. esm_input_dropout: float | int | None = 0.0
  124. embed_aa: bool | None = True
  125. bypass_lm: bool | None = False
  126. lddt_head_hid_dim: int | None = 128
  127. trunk: Union[dict, "TrunkConfig"] | None = None
  128. def __post_init__(self, **kwargs):
  129. if self.trunk is None:
  130. self.trunk = TrunkConfig()
  131. elif isinstance(self.trunk, dict):
  132. self.trunk = TrunkConfig(**self.trunk)
  133. super().__post_init__(**kwargs)
  134. @auto_docstring(checkpoint="facebook/esm-1b")
  135. @strict
  136. class EsmConfig(PreTrainedConfig):
  137. r"""
  138. mask_token_id (`int`, *optional*):
  139. The index of the mask token in the vocabulary. This must be included in the config because of the
  140. "mask-dropout" scaling trick, which will scale the inputs depending on the number of masked tokens.
  141. position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
  142. Type of position embedding. Choose either `"absolute"` or "rotary"`.
  143. emb_layer_norm_before (`bool`, *optional*):
  144. Whether to apply layer normalization after embeddings but before the main stem of the network.
  145. token_dropout (`bool`, defaults to `False`):
  146. When this is enabled, masked tokens are treated as if they had been dropped out by input dropout.
  147. is_folding_model (`bool`, defaults to `False`):
  148. When this is enabled, ESMFold model will be initialized.
  149. esmfold_config (`dict`, *optional*):
  150. Configuration to initiate the ESMFold module.
  151. vocab_list (`list`, *optional*):
  152. List of the vocabulary items.
  153. Examples:
  154. ```python
  155. >>> from transformers import EsmModel, EsmConfig
  156. >>> # Initializing a ESM facebook/esm-1b style configuration
  157. >>> configuration = EsmConfig(vocab_size=33)
  158. >>> # Initializing a model from the configuration
  159. >>> model = EsmModel(configuration)
  160. >>> # Accessing the model configuration
  161. >>> configuration = model.config
  162. ```"""
  163. model_type = "esm"
  164. sub_configs = {"esmfold_config": EsmFoldConfig}
  165. vocab_size: int | None = None
  166. mask_token_id: int | None = None
  167. pad_token_id: int | None = None
  168. hidden_size: int = 768
  169. num_hidden_layers: int = 12
  170. num_attention_heads: int = 12
  171. intermediate_size: int = 3072
  172. hidden_dropout_prob: float | None = 0.1
  173. attention_probs_dropout_prob: float | None = 0.1
  174. max_position_embeddings: int = 1026
  175. initializer_range: float = 0.02
  176. layer_norm_eps: float | None = 1e-12
  177. position_embedding_type: str | None = "absolute"
  178. use_cache: bool = True
  179. emb_layer_norm_before: bool | None = None
  180. token_dropout: bool | None = False
  181. is_folding_model: bool | None = False
  182. esmfold_config: dict | EsmFoldConfig | None = None
  183. vocab_list: list[str] | tuple[str, ...] | None = None
  184. is_decoder: bool | None = False
  185. add_cross_attention: bool | None = False
  186. tie_word_embeddings: bool = True
  187. def __post_init__(self, **kwargs):
  188. if self.is_folding_model:
  189. if self.esmfold_config is None:
  190. logger.info("No esmfold_config supplied for folding model, using default values.")
  191. self.esmfold_config = EsmFoldConfig()
  192. elif isinstance(self.esmfold_config, dict):
  193. self.esmfold_config = EsmFoldConfig(**self.esmfold_config)
  194. if self.vocab_list is None:
  195. logger.warning("No vocab_list supplied for folding model, assuming the ESM-2 vocabulary!")
  196. self.vocab_list = get_default_vocab_list()
  197. else:
  198. self.esmfold_config = None
  199. self.vocab_list = None
  200. if self.esmfold_config is not None and getattr(self.esmfold_config, "use_esm_attn_map", False):
  201. raise ValueError("The HuggingFace port of ESMFold does not support use_esm_attn_map at this time!")
  202. super().__post_init__(**kwargs)
  203. def get_default_vocab_list():
  204. return (
  205. "<cls>",
  206. "<pad>",
  207. "<eos>",
  208. "<unk>",
  209. "L",
  210. "A",
  211. "G",
  212. "V",
  213. "S",
  214. "E",
  215. "R",
  216. "T",
  217. "I",
  218. "D",
  219. "P",
  220. "K",
  221. "Q",
  222. "N",
  223. "F",
  224. "Y",
  225. "M",
  226. "H",
  227. "W",
  228. "C",
  229. "X",
  230. "B",
  231. "U",
  232. "Z",
  233. "O",
  234. ".",
  235. "-",
  236. "<null_1>",
  237. "<mask>",
  238. )
# Only EsmConfig is exported by a star-import; the other config classes must be imported by name.
__all__ = ["EsmConfig"]