# tokenization_parakeet.py
  1. # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import itertools
  15. from ...tokenization_utils_tokenizers import TokenizersBackend
  16. class ParakeetTokenizer(TokenizersBackend):
  17. """
  18. Inherits all methods from [`PreTrainedTokenizerFast`]. Users should refer to this superclass for more information regarding those methods,
  19. except for `_decode` which is overridden to adapt it to CTC decoding:
  20. 1. Group consecutive tokens
  21. 2. Filter out the blank token
  22. """
  23. def _decode(
  24. self,
  25. token_ids: int | list[int],
  26. skip_special_tokens: bool = False,
  27. clean_up_tokenization_spaces: bool | None = None,
  28. group_tokens: bool = True,
  29. **kwargs,
  30. ) -> str:
  31. if isinstance(token_ids, int):
  32. token_ids = [token_ids]
  33. if group_tokens:
  34. token_ids = [token_group[0] for token_group in itertools.groupby(token_ids)]
  35. # for CTC we filter out the blank token, which is the pad token
  36. token_ids = [token for token in token_ids if token != self.pad_token_id]
  37. return super()._decode(
  38. token_ids=token_ids,
  39. skip_special_tokens=skip_special_tokens,
  40. clean_up_tokenization_spaces=clean_up_tokenization_spaces,
  41. **kwargs,
  42. )
  43. __all__ = ["ParakeetTokenizer"]