tokenization_canine.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. # Copyright Google AI and The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """Tokenization classes for CANINE."""
  15. from ...tokenization_python import AddedToken, PreTrainedTokenizer
  16. from ...utils import logging
  # Module-level logger obtained from the package's own `logging` utilities
  # (`...utils.logging`, not the stdlib module), named after this module (`__name__`).
  logger = logging.get_logger(__name__)
  # Size of the codepoint-level vocabulary: the Unicode codespace runs from U+0000 to
  # U+10FFFF, i.e. 0x110000 (1,114,112) codepoints in total.
  UNICODE_VOCAB_SIZE = 0x110000

  # Canonical codepoints for special pseudo-characters.
  # Copied from https://github.com/google-research/language/blob/master/language/canine/special_codepoints.py
  # PAD is the NUL character; the other five symbols occupy consecutive codepoints
  # starting at U+E000 (CLS=U+E000, SEP=U+E001, BOS=U+E002, MASK=U+E003, RESERVED=U+E004).
  PAD = 0
  CLS, SEP, BOS, MASK, RESERVED = range(0xE000, 0xE005)

  # Special codepoint -> human-readable name.
  #
  # Apart from PAD, the special symbols use codepoints that are valid but designated
  # "Private Use": the Unicode Consortium will never assign characters to them, so
  # they are safe to claim here.
  #
  # NOTE: Do *NOT* add any sort of [UNK_CHAR] here. They are explicitly
  # excluded and should fail with a hard error.
  SPECIAL_CODEPOINTS: dict[int, str] = {
      CLS: "[CLS]",
      SEP: "[SEP]",
      BOS: "[BOS]",
      MASK: "[MASK]",
      PAD: "[PAD]",
      RESERVED: "[RESERVED]",
  }

  # Inverse lookup table: human-readable name -> special codepoint.
  SPECIAL_CODEPOINTS_BY_NAME: dict[str, int] = dict(zip(SPECIAL_CODEPOINTS.values(), SPECIAL_CODEPOINTS.keys()))
  class CanineTokenizer(PreTrainedTokenizer):
      r"""
      Construct a CANINE tokenizer (i.e. a character splitter). It turns text into a sequence of characters, and then
      converts each character into its Unicode code point.

      [`CanineTokenizer`] inherits from [`PreTrainedTokenizer`].

      Refer to superclass [`PreTrainedTokenizer`] for usage examples and documentation concerning parameters.

      Special-token arguments given as plain strings are wrapped in `AddedToken`; any other value is passed through
      unchanged.

      Args:
          bos_token (`str` or `AddedToken`, *optional*, defaults to `chr(CLS)`):
              Beginning-of-sequence token. Note that the default is the CLS codepoint (U+E000), not BOS (U+E002).
          eos_token (`str` or `AddedToken`, *optional*, defaults to `chr(SEP)`):
              End-of-sequence token.
          sep_token (`str` or `AddedToken`, *optional*, defaults to `chr(SEP)`):
              Separator token.
          cls_token (`str` or `AddedToken`, *optional*, defaults to `chr(CLS)`):
              Classifier token.
          pad_token (`str` or `AddedToken`, *optional*, defaults to `chr(PAD)`):
              Padding token (the NUL character).
          mask_token (`str` or `AddedToken`, *optional*, defaults to `chr(MASK)`):
              Mask token. A plain string is wrapped with `lstrip=True`, so it includes the space before it.
          add_prefix_space (`bool`, *optional*, defaults to `False`):
              Forwarded unchanged to the [`PreTrainedTokenizer`] constructor.
          model_max_length (`int`, *optional*, defaults to 2048):
              The maximum sentence length the model accepts.
      """

      model_input_names = ["input_ids", "attention_mask", "token_type_ids"]

      def __init__(
          self,
          bos_token=chr(CLS),
          eos_token=chr(SEP),
          sep_token=chr(SEP),
          cls_token=chr(CLS),
          pad_token=chr(PAD),
          mask_token=chr(MASK),
          add_prefix_space=False,
          model_max_length=2048,
          **kwargs,
      ):
          bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token
          eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token
          sep_token = AddedToken(sep_token, lstrip=False, rstrip=False) if isinstance(sep_token, str) else sep_token
          cls_token = AddedToken(cls_token, lstrip=False, rstrip=False) if isinstance(cls_token, str) else cls_token
          pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token
          # The mask token behaves like a normal word, i.e. it includes the space before it.
          mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token

          # Per-instance copies of the special-symbol tables:
          #   name -> codepoint, used by `_convert_token_to_id` to resolve names such as "[BOS]";
          #   codepoint -> name, used by `_convert_id_to_token` to render special codepoints readably.
          self._special_codepoints: dict[str, int] = dict(SPECIAL_CODEPOINTS_BY_NAME)
          self._special_codepoint_strings: dict[int, str] = dict(SPECIAL_CODEPOINTS)

          # NOTE(review): these are set before `super().__init__`, presumably because the base constructor may
          # already query the vocabulary (e.g. `vocab_size`); confirm before reordering.
          self._unicode_vocab_size = UNICODE_VOCAB_SIZE
          self._num_special_tokens = len(self._special_codepoints)

          super().__init__(
              bos_token=bos_token,
              eos_token=eos_token,
              sep_token=sep_token,
              cls_token=cls_token,
              pad_token=pad_token,
              mask_token=mask_token,
              add_prefix_space=add_prefix_space,
              model_max_length=model_max_length,
              token_type_ids_pattern="all_zeros",
              token_type_ids_include_special_tokens=True,
              special_tokens_pattern="cls_sep",
              **kwargs,
          )

      @property
      def vocab_size(self) -> int:
          """Number of ids in the base vocabulary: one per Unicode codepoint (added tokens are not counted)."""
          return self._unicode_vocab_size

      def get_vocab(self):
          """
          Return the token -> id mapping: each codepoint's character mapped to its own value, overlaid with the
          added-tokens encoder.

          Note: this materializes `vocab_size` (1,114,112) entries on every call, so avoid calling it in hot paths.
          """
          vocab = {chr(i): i for i in range(self.vocab_size)}
          vocab.update(self.added_tokens_encoder)
          return vocab

      def _tokenize(self, text: str) -> list[str]:
          """Tokenize a string (i.e. perform character splitting): one token per Python character (codepoint)."""
          return list(text)

      def _convert_token_to_id(self, token: str) -> int:
          """
          Converts a token (i.e. a Unicode character) to an id (i.e. its integer Unicode code point value).

          The human-readable special-symbol names produced by `_convert_id_to_token` (e.g. `"[BOS]"`) are mapped back
          to their codepoints, so the two conversions round-trip.

          Raises:
              `ValueError`: if `token` is neither a single character nor a known special-symbol name.
          """
          # Names are multi-character strings, so this lookup never shadows an ordinary single character.
          special_id = self._special_codepoints.get(token) if isinstance(token, str) else None
          if special_id is not None:  # `is not None`, because PAD's codepoint is 0
              return special_id
          try:
              return ord(token)
          except TypeError as err:
              raise ValueError(f"invalid token: '{token}'") from err

      def _convert_id_to_token(self, index: int) -> str:
          """
          Converts a Unicode code point (integer) to a token (str). In case it's a special code point, convert to
          its human-readable name (e.g. `0xE002` -> `"[BOS]"`).

          Raises:
              `ValueError`: if `index` is unhashable or cannot be converted by `chr` (wrong type, or outside
              `range(0x110000)`).
          """
          try:
              special_name = self._special_codepoint_strings.get(index)
              return special_name if special_name is not None else chr(index)
          except (TypeError, ValueError, OverflowError) as err:
              # Normalize every failure mode (unhashable index, wrong type, out of range, too large for a C int)
              # to a single ValueError, so callers only need to handle one exception type.
              raise ValueError(f"invalid id: {index}") from err

      def convert_tokens_to_string(self, tokens: list[str]) -> str:
          """Concatenate tokens back into a single string; characters are joined with no separator."""
          return "".join(tokens)
  125. raise ValueError(f"invalid id: {index}")
  126. def convert_tokens_to_string(self, tokens):
  127. return "".join(tokens)
  128. __all__ = ["CanineTokenizer"]