itertools.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277
  1. """
  2. Python polyfills for itertools
  3. """
  4. from __future__ import annotations
  5. import itertools
  6. import operator
  7. from collections.abc import Callable
  8. from typing import Optional, overload, TYPE_CHECKING, TypeAlias, TypeVar
  9. from ..decorators import substitute_in_graph
  10. if TYPE_CHECKING:
  11. from collections.abc import Iterable, Iterator
  12. __all__ = [
  13. "accumulate",
  14. "chain",
  15. "chain_from_iterable",
  16. "compress",
  17. "cycle",
  18. "dropwhile",
  19. "filterfalse",
  20. "islice",
  21. "tee",
  22. "zip_longest",
  23. "pairwise",
  24. ]
  25. _T = TypeVar("_T")
  26. _U = TypeVar("_U")
  27. _Predicate: TypeAlias = Callable[[_T], object]
  28. _T1 = TypeVar("_T1")
  29. _T2 = TypeVar("_T2")
  30. # Reference: https://docs.python.org/3/library/itertools.html#itertools.chain
  31. @substitute_in_graph(itertools.chain, is_embedded_type=True) # type: ignore[arg-type]
  32. def chain(*iterables: Iterable[_T]) -> Iterator[_T]:
  33. for iterable in iterables:
  34. yield from iterable
  35. # Reference: https://docs.python.org/3/library/itertools.html#itertools.accumulate
  36. @substitute_in_graph(itertools.accumulate, is_embedded_type=True) # type: ignore[arg-type]
  37. def accumulate(
  38. iterable: Iterable[_T],
  39. func: Optional[Callable[[_T, _T], _T]] = None,
  40. *,
  41. initial: Optional[_T] = None,
  42. ) -> Iterator[_T]:
  43. # call iter outside of the generator to match cypthon behavior
  44. iterator = iter(iterable)
  45. if func is None:
  46. func = operator.add
  47. def _accumulate(iterator: Iterator[_T]) -> Iterator[_T]:
  48. total = initial
  49. if total is None:
  50. try:
  51. total = next(iterator)
  52. except StopIteration:
  53. return
  54. yield total
  55. for element in iterator:
  56. total = func(total, element)
  57. yield total
  58. return _accumulate(iterator)
  59. @substitute_in_graph(itertools.chain.from_iterable) # type: ignore[arg-type]
  60. def chain_from_iterable(iterable: Iterable[Iterable[_T]], /) -> Iterator[_T]:
  61. # previous version of this code was:
  62. # return itertools.chain(*iterable)
  63. # If iterable is an infinite generator, this will lead to infinite recursion
  64. for it in iterable:
  65. yield from it
  66. chain.from_iterable = chain_from_iterable # type: ignore[attr-defined]
  67. # Reference: https://docs.python.org/3/library/itertools.html#itertools.compress
  68. @substitute_in_graph(itertools.compress, is_embedded_type=True) # type: ignore[arg-type]
  69. def compress(data: Iterable[_T], selectors: Iterable[_U], /) -> Iterator[_T]:
  70. return (datum for datum, selector in zip(data, selectors) if selector)
  71. # Reference: https://docs.python.org/3/library/itertools.html#itertools.cycle
  72. @substitute_in_graph(itertools.cycle, is_embedded_type=True) # type: ignore[arg-type]
  73. def cycle(iterable: Iterable[_T]) -> Iterator[_T]:
  74. iterator = iter(iterable)
  75. def _cycle(iterator: Iterator[_T]) -> Iterator[_T]:
  76. # pyrefly: ignore [implicit-any]
  77. saved = []
  78. for element in iterable:
  79. yield element
  80. saved.append(element)
  81. while saved:
  82. for element in saved:
  83. yield element
  84. return _cycle(iterator)
  85. # Reference: https://docs.python.org/3/library/itertools.html#itertools.dropwhile
  86. @substitute_in_graph(itertools.dropwhile, is_embedded_type=True) # type: ignore[arg-type]
  87. def dropwhile(predicate: _Predicate[_T], iterable: Iterable[_T], /) -> Iterator[_T]:
  88. # dropwhile(lambda x: x < 5, [1, 4, 6, 3, 8]) -> 6 3 8
  89. iterator = iter(iterable)
  90. for x in iterator:
  91. if not predicate(x):
  92. yield x
  93. break
  94. yield from iterator
  95. @substitute_in_graph(itertools.filterfalse, is_embedded_type=True) # type: ignore[arg-type]
  96. def filterfalse(function: _Predicate[_T], iterable: Iterable[_T], /) -> Iterator[_T]:
  97. it = iter(iterable)
  98. if function is None:
  99. return filter(operator.not_, it)
  100. else:
  101. return filter(lambda x: not function(x), it)
  102. # Reference: https://docs.python.org/3/library/itertools.html#itertools.islice
  103. @substitute_in_graph(itertools.islice, is_embedded_type=True) # type: ignore[arg-type]
  104. def islice(iterable: Iterable[_T], /, *args: int | None) -> Iterator[_T]:
  105. s = slice(*args)
  106. start = 0 if s.start is None else s.start
  107. stop = s.stop
  108. step = 1 if s.step is None else s.step
  109. if start < 0 or (stop is not None and stop < 0) or step <= 0:
  110. raise ValueError(
  111. "Indices for islice() must be None or an integer: 0 <= x <= sys.maxsize.",
  112. )
  113. if stop is None:
  114. # TODO: use indices = itertools.count() and merge implementation with the else branch
  115. # when we support infinite iterators
  116. next_i = start
  117. for i, element in enumerate(iterable):
  118. if i == next_i:
  119. yield element
  120. next_i += step
  121. else:
  122. indices = range(max(start, stop))
  123. next_i = start
  124. for i, element in zip(indices, iterable):
  125. if i == next_i:
  126. yield element
  127. next_i += step
  128. # Reference: https://docs.python.org/3/library/itertools.html#itertools.pairwise
  129. @substitute_in_graph(itertools.pairwise, is_embedded_type=True) # type: ignore[arg-type]
  130. def pairwise(iterable: Iterable[_T], /) -> Iterator[tuple[_T, _T]]:
  131. a = None
  132. first = True
  133. for b in iterable:
  134. if first:
  135. first = False
  136. else:
  137. yield a, b # type: ignore[misc]
  138. a = b
  139. # Reference: https://docs.python.org/3/library/itertools.html#itertools.tee
  140. @substitute_in_graph(itertools.tee)
  141. def tee(iterable: Iterable[_T], n: int = 2, /) -> tuple[Iterator[_T], ...]:
  142. iterator = iter(iterable)
  143. shared_link = [None, None]
  144. def _tee(link) -> Iterator[_T]: # type: ignore[no-untyped-def]
  145. try:
  146. while True:
  147. if link[1] is None:
  148. link[0] = next(iterator)
  149. link[1] = [None, None]
  150. value, link = link
  151. yield value
  152. except StopIteration:
  153. return
  154. return tuple(_tee(shared_link) for _ in range(n))
  155. @overload
  156. # pyrefly: ignore [inconsistent-overload]
  157. def zip_longest(
  158. iter1: Iterable[_T1],
  159. /,
  160. *,
  161. fillvalue: _U = ...,
  162. ) -> Iterator[tuple[_T1]]: ...
  163. @overload
  164. # pyrefly: ignore [inconsistent-overload]
  165. def zip_longest(
  166. iter1: Iterable[_T1],
  167. iter2: Iterable[_T2],
  168. /,
  169. ) -> Iterator[tuple[_T1 | None, _T2 | None]]: ...
  170. @overload
  171. # pyrefly: ignore [inconsistent-overload]
  172. def zip_longest(
  173. iter1: Iterable[_T1],
  174. iter2: Iterable[_T2],
  175. /,
  176. *,
  177. fillvalue: _U = ...,
  178. ) -> Iterator[tuple[_T1 | _U, _T2 | _U]]: ...
  179. @overload
  180. # pyrefly: ignore [inconsistent-overload]
  181. def zip_longest(
  182. iter1: Iterable[_T],
  183. iter2: Iterable[_T],
  184. iter3: Iterable[_T],
  185. /,
  186. *iterables: Iterable[_T],
  187. ) -> Iterator[tuple[_T | None, ...]]: ...
  188. @overload
  189. # pyrefly: ignore [inconsistent-overload]
  190. def zip_longest(
  191. iter1: Iterable[_T],
  192. iter2: Iterable[_T],
  193. iter3: Iterable[_T],
  194. /,
  195. *iterables: Iterable[_T],
  196. fillvalue: _U = ...,
  197. ) -> Iterator[tuple[_T | _U, ...]]: ...
  198. # Reference: https://docs.python.org/3/library/itertools.html#itertools.zip_longest
  199. @substitute_in_graph(itertools.zip_longest, is_embedded_type=True) # type: ignore[arg-type,misc]
  200. def zip_longest(
  201. *iterables: Iterable[_T],
  202. fillvalue: _U = None, # type: ignore[assignment]
  203. ) -> Iterator[tuple[_T | _U, ...]]:
  204. # zip_longest('ABCD', 'xy', fillvalue='-') -> Ax By C- D-
  205. iterators = list(map(iter, iterables))
  206. num_active = len(iterators)
  207. if not num_active:
  208. return
  209. while True:
  210. values = []
  211. for i, iterator in enumerate(iterators):
  212. try:
  213. value = next(iterator)
  214. except StopIteration:
  215. num_active -= 1
  216. if not num_active:
  217. return
  218. iterators[i] = itertools.repeat(fillvalue) # type: ignore[arg-type]
  219. value = fillvalue # type: ignore[assignment]
  220. values.append(value)
  221. yield tuple(values)