_chunk_utils.py 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. # Copyright 2022-present, the HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """Contains a utility to iterate by chunks over an iterator."""
  15. import itertools
  16. from collections.abc import Iterable
  17. from typing import TypeVar
  18. T = TypeVar("T")
  19. def chunk_iterable(iterable: Iterable[T], chunk_size: int) -> Iterable[Iterable[T]]:
  20. """Iterates over an iterator chunk by chunk.
  21. Taken from https://stackoverflow.com/a/8998040.
  22. See also https://github.com/huggingface/huggingface_hub/pull/920#discussion_r938793088.
  23. Args:
  24. iterable (`Iterable`):
  25. The iterable on which we want to iterate.
  26. chunk_size (`int`):
  27. Size of the chunks. Must be a strictly positive integer (e.g. >0).
  28. Example:
  29. ```python
  30. >>> from huggingface_hub.utils import chunk_iterable
  31. >>> for items in chunk_iterable(range(17), chunk_size=8):
  32. ... print(items)
  33. # [0, 1, 2, 3, 4, 5, 6, 7]
  34. # [8, 9, 10, 11, 12, 13, 14, 15]
  35. # [16] # smaller last chunk
  36. ```
  37. Raises:
  38. [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
  39. If `chunk_size` <= 0.
  40. > [!WARNING]
  41. > The last chunk can be smaller than `chunk_size`.
  42. """
  43. if not isinstance(chunk_size, int) or chunk_size <= 0:
  44. raise ValueError("`chunk_size` must be a strictly positive integer (>0).")
  45. iterator = iter(iterable)
  46. while True:
  47. try:
  48. next_item = next(iterator)
  49. except StopIteration:
  50. return
  51. yield itertools.chain((next_item,), itertools.islice(iterator, chunk_size - 1))