mistral.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. from tokenizers import Regex, Tokenizer, decoders, pre_tokenizers, processors
  2. from tokenizers.models import BPE
  3. from transformers.convert_slow_tokenizer import bytes_to_unicode
  4. from transformers.tokenization_utils_tokenizers import PreTrainedTokenizerFast
  5. class MistralConverter:
  6. """
  7. A general tiktoken converter.
  8. """
  9. def __init__(
  10. self,
  11. vocab=None,
  12. pattern=r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""",
  13. add_prefix_space=False,
  14. additional_special_tokens=None,
  15. **kwargs,
  16. ):
  17. self.vocab = vocab
  18. self.pattern = pattern
  19. self.add_prefix_space = add_prefix_space
  20. self.additional_special_tokens = additional_special_tokens
  21. def extract_vocab_merges_from_model(self, vocab: str):
  22. bpe_ranks = vocab
  23. byte_encoder = bytes_to_unicode()
  24. def token_bytes_to_string(b):
  25. return "".join([byte_encoder[ord(char)] for char in b.decode("latin-1")])
  26. merges = []
  27. vocab = {}
  28. for idx, (token, rank) in enumerate(bpe_ranks.items()):
  29. if token not in self.additional_special_tokens:
  30. vocab[token_bytes_to_string(token)] = idx
  31. if len(token) == 1:
  32. continue
  33. local = []
  34. for index in range(1, len(token)):
  35. piece_l, piece_r = token[:index], token[index:]
  36. if piece_l in bpe_ranks and piece_r in bpe_ranks and (piece_l + piece_r) in bpe_ranks:
  37. local.append((piece_l, piece_r, rank))
  38. local = sorted(local, key=lambda x: (bpe_ranks[x[0]], bpe_ranks[x[1]]), reverse=False)
  39. merges.extend(local)
  40. else:
  41. vocab[token] = idx
  42. merges = sorted(merges, key=lambda val: val[2], reverse=False)
  43. merges = [(token_bytes_to_string(val[0]), token_bytes_to_string(val[1])) for val in merges]
  44. return vocab, merges
  45. def tokenizer(self):
  46. vocab_scores, merges = self.extract_vocab_merges_from_model(self.vocab)
  47. tokenizer = Tokenizer(BPE(vocab_scores, merges, fuse_unk=False))
  48. if hasattr(tokenizer.model, "ignore_merges"):
  49. tokenizer.model.ignore_merges = True
  50. return tokenizer
  51. def converted(self) -> Tokenizer:
  52. tokenizer = self.tokenizer()
  53. tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
  54. [
  55. pre_tokenizers.Split(Regex(self.pattern), behavior="isolated", invert=False),
  56. pre_tokenizers.ByteLevel(add_prefix_space=self.add_prefix_space, use_regex=False),
  57. ]
  58. )
  59. tokenizer.decoder = decoders.ByteLevel()
  60. tokenizer.add_special_tokens(self.additional_special_tokens)
  61. tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)
  62. return tokenizer
  63. def convert_tekken_tokenizer(tokenizer_file: str):
  64. """Convert a "tekken" tokenizer to a fast Tokenizer."""
  65. # Tekken format -- need to use the Converter
  66. from mistral_common.tokens.tokenizers.base import SpecialTokens
  67. from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
  68. # Load directly using their lib
  69. mistral_tokenizer = MistralTokenizer.from_file(tokenizer_file)
  70. # Extract vocab and special tokens
  71. vocab = mistral_tokenizer.instruct_tokenizer.tokenizer._tekken_token2id_nospecial
  72. sorted_tokens = sorted(mistral_tokenizer.instruct_tokenizer.tokenizer._all_special_tokens, key=lambda x: x["rank"])
  73. all_special = [token["token_str"] for token in sorted_tokens]
  74. specials_tokens = {token: idx for idx, token in enumerate(all_special)}
  75. specials_tokens.update(vocab)
  76. vocab = specials_tokens
  77. # TODO(juliendenize): expose this in mistral-common to avoid accessing private attributes
  78. # and improve maintainability
  79. pattern = mistral_tokenizer.instruct_tokenizer.tokenizer._model._pat_str
  80. # Convert
  81. tokenizer = PreTrainedTokenizerFast(
  82. tokenizer_object=MistralConverter(
  83. vocab=vocab, additional_special_tokens=all_special, pattern=pattern
  84. ).converted()
  85. )
  86. # Post-process
  87. tokenizer.add_special_tokens({"additional_special_tokens": all_special})
  88. MAP_SPECAL = {
  89. "bos_token": SpecialTokens.bos.value,
  90. "eos_token": SpecialTokens.eos.value,
  91. "pad_token": SpecialTokens.pad.value,
  92. "unk_token": SpecialTokens.unk.value,
  93. }
  94. for special_key, special_token in MAP_SPECAL.items():
  95. if special_token in all_special:
  96. tokenizer.add_special_tokens({special_key: special_token})
  97. return tokenizer