tokenization_esm.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. # Copyright 2022 Meta and The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """Tokenization classes for ESM."""
  15. import os
  16. from ...tokenization_python import PreTrainedTokenizer
  17. from ...utils import logging
# Module-level logger, named after this module.
logger = logging.get_logger(__name__)

# Maps the `vocab_file` initializer argument to the file name "vocab.txt"; assigned to
# `EsmTokenizer.vocab_files_names` below.
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
  20. def load_vocab_file(vocab_file):
  21. with open(vocab_file, "r") as f:
  22. lines = f.read().splitlines()
  23. return [l.strip() for l in lines]
  24. class EsmTokenizer(PreTrainedTokenizer):
  25. """
  26. Constructs an ESM tokenizer.
  27. """
  28. vocab_files_names = VOCAB_FILES_NAMES
  29. model_input_names = ["input_ids", "attention_mask"]
  30. def __init__(
  31. self,
  32. vocab_file,
  33. unk_token="<unk>",
  34. cls_token="<cls>",
  35. pad_token="<pad>",
  36. mask_token="<mask>",
  37. eos_token="<eos>",
  38. **kwargs,
  39. ):
  40. self.all_tokens = load_vocab_file(vocab_file)
  41. self._id_to_token = dict(enumerate(self.all_tokens))
  42. self._token_to_id = {tok: ind for ind, tok in enumerate(self.all_tokens)}
  43. super().__init__(
  44. unk_token=unk_token,
  45. cls_token=cls_token,
  46. pad_token=pad_token,
  47. mask_token=mask_token,
  48. eos_token=eos_token,
  49. **kwargs,
  50. )
  51. # TODO, all the tokens are added? But they are also part of the vocab... bit strange.
  52. # none of them are special, but they all need special splitting.
  53. self.unique_no_split_tokens = self.all_tokens
  54. self._update_trie(self.unique_no_split_tokens)
  55. def _convert_id_to_token(self, index: int) -> str:
  56. return self._id_to_token.get(index, self.unk_token)
  57. def _convert_token_to_id(self, token: str) -> int:
  58. return self._token_to_id.get(token, self._token_to_id.get(self.unk_token))
  59. def _tokenize(self, text, **kwargs):
  60. return text.split()
  61. def get_vocab(self):
  62. base_vocab = self._token_to_id.copy()
  63. base_vocab.update(self.added_tokens_encoder)
  64. return base_vocab
  65. def token_to_id(self, token: str) -> int:
  66. return self._token_to_id.get(token, self._token_to_id.get(self.unk_token))
  67. def id_to_token(self, index: int) -> str:
  68. return self._id_to_token.get(index, self.unk_token)
  69. def build_inputs_with_special_tokens(
  70. self, token_ids_0: list[int], token_ids_1: list[int] | None = None
  71. ) -> list[int]:
  72. cls = [self.cls_token_id]
  73. sep = [self.eos_token_id] # No sep token in ESM vocabulary
  74. if token_ids_1 is None:
  75. if self.eos_token_id is None:
  76. return cls + token_ids_0
  77. else:
  78. return cls + token_ids_0 + sep
  79. elif self.eos_token_id is None:
  80. raise ValueError("Cannot tokenize multiple sequences when EOS token is not set!")
  81. return cls + token_ids_0 + sep + token_ids_1 + sep # Multiple inputs always have an EOS token
  82. def get_special_tokens_mask(
  83. self, token_ids_0: list, token_ids_1: list | None = None, already_has_special_tokens: bool = False
  84. ) -> list[int]:
  85. """
  86. Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
  87. special tokens using the tokenizer `prepare_for_model` or `encode_plus` methods.
  88. Args:
  89. token_ids_0 (`list[int]`):
  90. List of ids of the first sequence.
  91. token_ids_1 (`list[int]`, *optional*):
  92. List of ids of the second sequence.
  93. already_has_special_tokens (`bool`, *optional*, defaults to `False`):
  94. Whether or not the token list is already formatted with special tokens for the model.
  95. Returns:
  96. A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
  97. """
  98. if already_has_special_tokens:
  99. if token_ids_1 is not None:
  100. raise ValueError(
  101. "You should not supply a second sequence if the provided sequence of "
  102. "ids is already formatted with special tokens for the model."
  103. )
  104. return [1 if token in self.all_special_ids else 0 for token in token_ids_0]
  105. mask = [1] + ([0] * len(token_ids_0)) + [1]
  106. if token_ids_1 is not None:
  107. mask += [0] * len(token_ids_1) + [1]
  108. return mask
  109. def save_vocabulary(self, save_directory, filename_prefix):
  110. vocab_file = os.path.join(save_directory, (filename_prefix + "-" if filename_prefix else "") + "vocab.txt")
  111. with open(vocab_file, "w") as f:
  112. f.write("\n".join(self.all_tokens))
  113. return (vocab_file,)
  114. @property
  115. def vocab_size(self) -> int:
  116. return len(self.all_tokens)
# Names exported by `from <this module> import *`.
__all__ = ["EsmTokenizer"]