iter.py 46 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761277127812791280128112821283128412851286
  1. import collections
  2. import random
  3. import threading
  4. import time
  5. from contextlib import contextmanager
  6. from typing import Any, Callable, Generic, Iterable, List, TypeVar
  7. import ray
  8. from ray.util.annotations import Deprecated
  9. from ray.util.iter_metrics import MetricsContext, SharedMetrics
  10. # The type of an iterator element.
  11. T = TypeVar("T")
  12. U = TypeVar("U")
  13. @Deprecated
  14. def from_items(
  15. items: List[T], num_shards: int = 2, repeat: bool = False
  16. ) -> "ParallelIterator[T]":
  17. """Create a parallel iterator from an existing set of objects.
  18. The objects will be divided round-robin among the number of shards.
  19. Args:
  20. items: The list of items to iterate over.
  21. num_shards: The number of worker actors to create.
  22. repeat: Whether to cycle over the items forever.
  23. """
  24. shards = [[] for _ in range(num_shards)]
  25. for i, item in enumerate(items):
  26. shards[i % num_shards].append(item)
  27. name = "from_items[{}, {}, shards={}{}]".format(
  28. items and type(items[0]).__name__ or "None",
  29. len(items),
  30. num_shards,
  31. ", repeat=True" if repeat else "",
  32. )
  33. return from_iterators(shards, repeat=repeat, name=name)
  34. @Deprecated
  35. def from_range(
  36. n: int, num_shards: int = 2, repeat: bool = False
  37. ) -> "ParallelIterator[int]":
  38. """Create a parallel iterator over the range 0..n.
  39. The range will be partitioned sequentially among the number of shards.
  40. Args:
  41. n: The max end of the range of numbers.
  42. num_shards: The number of worker actors to create.
  43. repeat: Whether to cycle over the range forever.
  44. """
  45. generators = []
  46. shard_size = n // num_shards
  47. for i in range(num_shards):
  48. start = i * shard_size
  49. if i == num_shards - 1:
  50. end = n
  51. else:
  52. end = (i + 1) * shard_size
  53. generators.append(range(start, end))
  54. name = (
  55. f"from_range[{n}, shards={num_shards}" f"{', repeat=True' if repeat else ''}]"
  56. )
  57. return from_iterators(
  58. generators,
  59. repeat=repeat,
  60. name=name,
  61. )
  62. @Deprecated
  63. def from_iterators(
  64. generators: List[Iterable[T]], repeat: bool = False, name=None
  65. ) -> "ParallelIterator[T]":
  66. """Create a parallel iterator from a list of iterables.
  67. An iterable can be a conatiner (list, str, tuple, set, etc.),
  68. a generator, or a custom class that implements __iter__ or __getitem__.
  69. An actor will be created for each iterable.
  70. Examples:
  71. >>> # Create using a list of generators.
  72. >>> from_iterators([range(100), range(100)])
  73. >>> # Certain generators are not serializable.
  74. >>> from_iterators([(x for x in range(100))])
  75. ... TypeError: can't pickle generator objects
  76. >>> # So use lambda functions instead.
  77. >>> # Lambda functions are serializable.
  78. >>> from_iterators([lambda: (x for x in range(100))])
  79. Args:
  80. generators: A list of Python iterables or lambda
  81. functions that produce an iterable when called. We allow lambda
  82. functions since certain generators might not be serializable,
  83. but a lambda that returns it can be.
  84. repeat: Whether to cycle over the iterators forever.
  85. name: Optional name to give the iterator.
  86. """
  87. worker_cls = ray.remote(ParallelIteratorWorker)
  88. actors = [worker_cls.remote(g, repeat) for g in generators]
  89. if not name:
  90. name = "from_iterators[shards={}{}]".format(
  91. len(generators), ", repeat=True" if repeat else ""
  92. )
  93. return from_actors(actors, name=name)
  94. @Deprecated
  95. def from_actors(
  96. actors: List["ray.actor.ActorHandle"], name=None
  97. ) -> "ParallelIterator[T]":
  98. """Create a parallel iterator from an existing set of actors.
  99. Each actor must subclass the ParallelIteratorWorker interface.
  100. Args:
  101. actors: List of actors that each implement
  102. ParallelIteratorWorker.
  103. name: Optional name to give the iterator.
  104. """
  105. if not name:
  106. name = f"from_actors[shards={len(actors)}]"
  107. return ParallelIterator([_ActorSet(actors, [])], name, parent_iterators=[])
  108. @Deprecated
  109. class ParallelIterator(Generic[T]):
  110. """A parallel iterator over a set of remote actors.
  111. This can be used to iterate over a fixed set of task results
  112. (like an actor pool), or a stream of data (e.g., a fixed range of numbers,
  113. an infinite stream of RLlib rollout results).
  114. This class is **serializable** and can be passed to other remote
  115. tasks and actors. However, each shard should be read from at most one
  116. process at a time.
  117. Examples:
  118. >>> # Applying a function over items in parallel.
  119. >>> it = ray.util.iter.from_items([1, 2, 3], num_shards=2)
  120. ... <__main__.ParallelIterator object>
  121. >>> it = it.for_each(lambda x: x * 2).gather_sync()
  122. ... <__main__.LocalIterator object>
  123. >>> print(list(it))
  124. ... [2, 4, 6]
  125. >>> # Creating from generators.
  126. >>> it = ray.util.iter.from_iterators([range(3), range(3)])
  127. ... <__main__.ParallelIterator object>
  128. >>> print(list(it.gather_sync()))
  129. ... [0, 0, 1, 1, 2, 2]
  130. >>> # Accessing the individual shards of an iterator.
  131. >>> it = ray.util.iter.from_range(10, num_shards=2)
  132. ... <__main__.ParallelIterator object>
  133. >>> it0 = it.get_shard(0)
  134. ... <__main__.LocalIterator object>
  135. >>> print(list(it0))
  136. ... [0, 1, 2, 3, 4]
  137. >>> it1 = it.get_shard(1)
  138. ... <__main__.LocalIterator object>
  139. >>> print(list(it1))
  140. ... [5, 6, 7, 8, 9]
  141. >>> # Gathering results from actors synchronously in parallel.
  142. >>> it = ray.util.iter.from_actors(workers)
  143. ... <__main__.ParallelIterator object>
  144. >>> it = it.batch_across_shards()
  145. ... <__main__.LocalIterator object>
  146. >>> print(next(it))
  147. ... [worker_1_result_1, worker_2_result_1]
  148. >>> print(next(it))
  149. ... [worker_1_result_2, worker_2_result_2]
  150. """
  151. def __init__(
  152. self,
  153. actor_sets: List["_ActorSet"],
  154. name: str,
  155. parent_iterators: List["ParallelIterator[Any]"],
  156. ):
  157. """Create a parallel iterator (this is an internal function)."""
  158. # We track multiple sets of actors to support parallel .union().
  159. self.actor_sets = actor_sets
  160. self.name = name
  161. # keep explicit reference to parent iterator for repartition
  162. self.parent_iterators = parent_iterators
  163. def __iter__(self):
  164. raise TypeError(
  165. "You must use it.gather_sync() or it.gather_async() to "
  166. "iterate over the results of a ParallelIterator."
  167. )
  168. def __str__(self):
  169. return repr(self)
  170. def __repr__(self):
  171. return f"ParallelIterator[{self.name}]"
  172. def _with_transform(self, local_it_fn, name):
  173. """Helper function to create new Parallel Iterator"""
  174. return ParallelIterator(
  175. [a.with_transform(local_it_fn) for a in self.actor_sets],
  176. name=self.name + name,
  177. parent_iterators=self.parent_iterators,
  178. )
  179. def transform(
  180. self, fn: Callable[[Iterable[T]], Iterable[U]]
  181. ) -> "ParallelIterator[U]":
  182. """Remotely transform the iterator.
  183. This is advanced version of for_each that allows you to apply arbitrary
  184. generator transformations over the iterator. Prefer to use .for_each()
  185. when possible for simplicity.
  186. Args:
  187. fn: function to use to transform the iterator. The function
  188. should pass through instances of _NextValueNotReady that appear
  189. in its input iterator. Note that this function is only called
  190. **once** over the input iterator.
  191. Returns:
  192. ParallelIterator[U]: a parallel iterator.
  193. Examples:
  194. >>> def f(it):
  195. ... for x in it:
  196. ... if x % 2 == 0:
  197. ... yield x
  198. >>> from_range(10, 1).transform(f).gather_sync().take(5)
  199. ... [0, 2, 4, 6, 8]
  200. """
  201. return self._with_transform(
  202. lambda local_it: local_it.transform(fn), ".transform()"
  203. )
  204. def for_each(
  205. self, fn: Callable[[T], U], max_concurrency=1, resources=None
  206. ) -> "ParallelIterator[U]":
  207. """Remotely apply fn to each item in this iterator.
  208. If `max_concurrency` == 1 then `fn` will be executed serially by each
  209. shards
  210. `max_concurrency` should be used to achieve a high degree of
  211. parallelism without the overhead of increasing the number of shards
  212. (which are actor based). If `max_concurrency` is not 1, this function
  213. provides no semantic guarantees on the output order.
  214. Results will be returned as soon as they are ready.
  215. A performance note: When executing concurrently, this function
  216. maintains its own internal buffer. If `num_async` is `n` and
  217. max_concur is `k` then the total number of buffered objects could be up
  218. to `n + k - 1`
  219. Args:
  220. fn: function to apply to each item.
  221. max_concurrency: max number of concurrent calls to fn per
  222. shard. If 0, then apply all operations concurrently.
  223. resources: resources that the function requires to execute.
  224. This has the same default as `ray.remote` and is only used
  225. when `max_concurrency > 1`.
  226. Returns:
  227. ParallelIterator[U]: a parallel iterator whose elements have `fn`
  228. applied.
  229. Examples:
  230. >>> next(from_range(4).for_each(
  231. lambda x: x * 2,
  232. max_concur=2,
  233. resources={"num_cpus": 0.1}).gather_sync()
  234. )
  235. ... [0, 2, 4, 8]
  236. """
  237. assert max_concurrency >= 0, "max_concurrency must be non-negative."
  238. return self._with_transform(
  239. lambda local_it: local_it.for_each(fn, max_concurrency, resources),
  240. ".for_each()",
  241. )
  242. def filter(self, fn: Callable[[T], bool]) -> "ParallelIterator[T]":
  243. """Remotely filter items from this iterator.
  244. Args:
  245. fn: returns False for items to drop from the iterator.
  246. Examples:
  247. >>> it = from_items([0, 1, 2]).filter(lambda x: x > 0)
  248. >>> next(it.gather_sync())
  249. ... [1, 2]
  250. """
  251. return self._with_transform(lambda local_it: local_it.filter(fn), ".filter()")
  252. def batch(self, n: int) -> "ParallelIterator[List[T]]":
  253. """Remotely batch together items in this iterator.
  254. Args:
  255. n: Number of items to batch together.
  256. Examples:
  257. >>> next(from_range(10, 1).batch(4).gather_sync())
  258. ... [0, 1, 2, 3]
  259. """
  260. return self._with_transform(lambda local_it: local_it.batch(n), f".batch({n})")
  261. def flatten(self) -> "ParallelIterator[T[0]]":
  262. """Flatten batches of items into individual items.
  263. Examples:
  264. >>> next(from_range(10, 1).batch(4).flatten())
  265. ... 0
  266. """
  267. return self._with_transform(lambda local_it: local_it.flatten(), ".flatten()")
  268. def combine(self, fn: Callable[[T], List[U]]) -> "ParallelIterator[U]":
  269. """Transform and then combine items horizontally.
  270. This is the equivalent of for_each(fn).flatten() (flat map).
  271. """
  272. it = self.for_each(fn).flatten()
  273. it.name = self.name + ".combine()"
  274. return it
  275. def local_shuffle(
  276. self, shuffle_buffer_size: int, seed: int = None
  277. ) -> "ParallelIterator[T]":
  278. """Remotely shuffle items of each shard independently
  279. Args:
  280. shuffle_buffer_size: The algorithm fills a buffer with
  281. shuffle_buffer_size elements and randomly samples elements from
  282. this buffer, replacing the selected elements with new elements.
  283. For perfect shuffling, this argument should be greater than or
  284. equal to the largest iterator size.
  285. seed: Seed to use for
  286. randomness. Default value is None.
  287. Returns:
  288. A ParallelIterator with a local shuffle applied on the base
  289. iterator
  290. Examples:
  291. >>> it = from_range(10, 1).local_shuffle(shuffle_buffer_size=2)
  292. >>> it = it.gather_sync()
  293. >>> next(it)
  294. 0
  295. >>> next(it)
  296. 2
  297. >>> next(it)
  298. 3
  299. >>> next(it)
  300. 1
  301. """
  302. return self._with_transform(
  303. lambda local_it: local_it.shuffle(shuffle_buffer_size, seed),
  304. ".local_shuffle(shuffle_buffer_size={}, seed={})".format(
  305. shuffle_buffer_size, str(seed) if seed is not None else "None"
  306. ),
  307. )
  308. def repartition(
  309. self, num_partitions: int, batch_ms: int = 0
  310. ) -> "ParallelIterator[T]":
  311. """Returns a new ParallelIterator instance with num_partitions shards.
  312. The new iterator contains the same data in this instance except with
  313. num_partitions shards. The data is split in round-robin fashion for
  314. the new ParallelIterator.
  315. Args:
  316. num_partitions: The number of shards to use for the new
  317. ParallelIterator
  318. batch_ms: Batches items for batch_ms milliseconds
  319. on each shard before retrieving it.
  320. Increasing batch_ms increases latency but improves throughput.
  321. Returns:
  322. A ParallelIterator with num_partitions number of shards and the
  323. data of this ParallelIterator split round-robin among the new
  324. number of shards.
  325. Examples:
  326. >>> it = from_range(8, 2)
  327. >>> it = it.repartition(3)
  328. >>> list(it.get_shard(0))
  329. [0, 4, 3, 7]
  330. >>> list(it.get_shard(1))
  331. [1, 5]
  332. >>> list(it.get_shard(2))
  333. [2, 6]
  334. """
  335. # initialize the local iterators for all the actors
  336. all_actors = []
  337. for actor_set in self.actor_sets:
  338. actor_set.init_actors()
  339. all_actors.extend(actor_set.actors)
  340. def base_iterator(num_partitions, partition_index, timeout=None):
  341. futures = {}
  342. for a in all_actors:
  343. futures[
  344. a.par_iter_slice_batch.remote(
  345. step=num_partitions, start=partition_index, batch_ms=batch_ms
  346. )
  347. ] = a
  348. while futures:
  349. pending = list(futures)
  350. if timeout is None:
  351. # First try to do a batch wait for efficiency.
  352. ready, _ = ray.wait(pending, num_returns=len(pending), timeout=0)
  353. # Fall back to a blocking wait.
  354. if not ready:
  355. ready, _ = ray.wait(pending, num_returns=1)
  356. else:
  357. ready, _ = ray.wait(
  358. pending, num_returns=len(pending), timeout=timeout
  359. )
  360. for obj_ref in ready:
  361. actor = futures.pop(obj_ref)
  362. try:
  363. batch = ray.get(obj_ref)
  364. futures[
  365. actor.par_iter_slice_batch.remote(
  366. step=num_partitions,
  367. start=partition_index,
  368. batch_ms=batch_ms,
  369. )
  370. ] = actor
  371. for item in batch:
  372. yield item
  373. except StopIteration:
  374. pass
  375. # Always yield after each round of wait with timeout.
  376. if timeout is not None:
  377. yield _NextValueNotReady()
  378. def make_gen_i(i):
  379. return lambda: base_iterator(num_partitions, i)
  380. name = self.name + f".repartition[num_partitions={num_partitions}]"
  381. generators = [make_gen_i(s) for s in range(num_partitions)]
  382. worker_cls = ray.remote(ParallelIteratorWorker)
  383. actors = [worker_cls.remote(g, repeat=False) for g in generators]
  384. # need explicit reference to self so actors in this instance do not die
  385. return ParallelIterator([_ActorSet(actors, [])], name, parent_iterators=[self])
  386. def gather_sync(self) -> "LocalIterator[T]":
  387. """Returns a local iterable for synchronous iteration.
  388. New items will be fetched from the shards on-demand as the iterator
  389. is stepped through.
  390. This is the equivalent of batch_across_shards().flatten().
  391. Examples:
  392. >>> it = from_range(100, 1).gather_sync()
  393. >>> next(it)
  394. ... 0
  395. >>> next(it)
  396. ... 1
  397. >>> next(it)
  398. ... 2
  399. """
  400. it = self.batch_across_shards().flatten()
  401. it.name = f"{self}.gather_sync()"
  402. return it
  403. def batch_across_shards(self) -> "LocalIterator[List[T]]":
  404. """Iterate over the results of multiple shards in parallel.
  405. Examples:
  406. >>> it = from_iterators([range(3), range(3)])
  407. >>> next(it.batch_across_shards())
  408. ... [0, 0]
  409. """
  410. def base_iterator(timeout=None):
  411. active = []
  412. for actor_set in self.actor_sets:
  413. actor_set.init_actors()
  414. active.extend(actor_set.actors)
  415. futures = [a.par_iter_next.remote() for a in active]
  416. while active:
  417. try:
  418. yield ray.get(futures, timeout=timeout)
  419. futures = [a.par_iter_next.remote() for a in active]
  420. # Always yield after each round of gets with timeout.
  421. if timeout is not None:
  422. yield _NextValueNotReady()
  423. except TimeoutError:
  424. yield _NextValueNotReady()
  425. except StopIteration:
  426. # Find and remove the actor that produced StopIteration.
  427. results = []
  428. for a, f in zip(list(active), futures):
  429. try:
  430. results.append(ray.get(f))
  431. except StopIteration:
  432. active.remove(a)
  433. if results:
  434. yield results
  435. futures = [a.par_iter_next.remote() for a in active]
  436. name = f"{self}.batch_across_shards()"
  437. return LocalIterator(base_iterator, SharedMetrics(), name=name)
  438. def gather_async(self, batch_ms=0, num_async=1) -> "LocalIterator[T]":
  439. """Returns a local iterable for asynchronous iteration.
  440. New items will be fetched from the shards asynchronously as soon as
  441. the previous one is computed. Items arrive in non-deterministic order.
  442. Arguments:
  443. batch_ms: Batches items for batch_ms milliseconds
  444. on each shard before retrieving it.
  445. Increasing batch_ms increases latency but improves throughput.
  446. If this value is 0, then items are returned immediately.
  447. num_async: The max number of async requests in flight
  448. per actor. Increasing this improves the amount of pipeline
  449. parallelism in the iterator.
  450. Examples:
  451. >>> it = from_range(100, 1).gather_async()
  452. >>> next(it)
  453. ... 3
  454. >>> next(it)
  455. ... 0
  456. >>> next(it)
  457. ... 1
  458. """
  459. if num_async < 1:
  460. raise ValueError("queue depth must be positive")
  461. if batch_ms < 0:
  462. raise ValueError("batch time must be positive")
  463. # Forward reference to the returned iterator.
  464. local_iter = None
  465. def base_iterator(timeout=None):
  466. all_actors = []
  467. for actor_set in self.actor_sets:
  468. actor_set.init_actors()
  469. all_actors.extend(actor_set.actors)
  470. futures = {}
  471. for _ in range(num_async):
  472. for a in all_actors:
  473. futures[a.par_iter_next_batch.remote(batch_ms)] = a
  474. while futures:
  475. pending = list(futures)
  476. if timeout is None:
  477. # First try to do a batch wait for efficiency.
  478. ready, _ = ray.wait(pending, num_returns=len(pending), timeout=0)
  479. # Fall back to a blocking wait.
  480. if not ready:
  481. ready, _ = ray.wait(pending, num_returns=1)
  482. else:
  483. ready, _ = ray.wait(
  484. pending, num_returns=len(pending), timeout=timeout
  485. )
  486. for obj_ref in ready:
  487. actor = futures.pop(obj_ref)
  488. try:
  489. local_iter.shared_metrics.get().current_actor = actor
  490. batch = ray.get(obj_ref)
  491. futures[actor.par_iter_next_batch.remote(batch_ms)] = actor
  492. for item in batch:
  493. yield item
  494. except StopIteration:
  495. pass
  496. # Always yield after each round of wait with timeout.
  497. if timeout is not None:
  498. yield _NextValueNotReady()
  499. name = f"{self}.gather_async()"
  500. local_iter = LocalIterator(base_iterator, SharedMetrics(), name=name)
  501. return local_iter
  502. def take(self, n: int) -> List[T]:
  503. """Return up to the first n items from this iterator."""
  504. return self.gather_sync().take(n)
  505. def show(self, n: int = 20):
  506. """Print up to the first n items from this iterator."""
  507. return self.gather_sync().show(n)
  508. def union(self, other: "ParallelIterator[T]") -> "ParallelIterator[T]":
  509. """Return an iterator that is the union of this and the other."""
  510. if not isinstance(other, ParallelIterator):
  511. raise TypeError(
  512. f"other must be of type ParallelIterator, got {type(other)}"
  513. )
  514. actor_sets = []
  515. actor_sets.extend(self.actor_sets)
  516. actor_sets.extend(other.actor_sets)
  517. # if one of these iterators is a result of a repartition, we need to
  518. # keep an explicit reference to its parent iterator
  519. return ParallelIterator(
  520. actor_sets,
  521. f"ParallelUnion[{self}, {other}]",
  522. parent_iterators=self.parent_iterators + other.parent_iterators,
  523. )
  524. def select_shards(self, shards_to_keep: List[int]) -> "ParallelIterator[T]":
  525. """Return a child iterator that only iterates over given shards.
  526. It is the user's responsibility to ensure child iterators are operating
  527. over disjoint sub-sets of this iterator's shards.
  528. """
  529. if len(self.actor_sets) > 1:
  530. raise ValueError("select_shards() is not allowed after union()")
  531. if len(shards_to_keep) == 0:
  532. raise ValueError("at least one shard must be selected")
  533. old_actor_set = self.actor_sets[0]
  534. new_actors = [
  535. a for (i, a) in enumerate(old_actor_set.actors) if i in shards_to_keep
  536. ]
  537. assert len(new_actors) == len(shards_to_keep), "Invalid actor index"
  538. new_actor_set = _ActorSet(new_actors, old_actor_set.transforms)
  539. return ParallelIterator(
  540. [new_actor_set],
  541. f"{self}.select_shards({len(shards_to_keep)} total)",
  542. parent_iterators=self.parent_iterators,
  543. )
  544. def num_shards(self) -> int:
  545. """Return the number of worker actors backing this iterator."""
  546. return sum(len(a.actors) for a in self.actor_sets)
  547. def shards(self) -> List["LocalIterator[T]"]:
  548. """Return the list of all shards."""
  549. return [self.get_shard(i) for i in range(self.num_shards())]
  550. def get_shard(
  551. self, shard_index: int, batch_ms: int = 0, num_async: int = 1
  552. ) -> "LocalIterator[T]":
  553. """Return a local iterator for the given shard.
  554. The iterator is guaranteed to be serializable and can be passed to
  555. remote tasks or actors.
  556. Arguments:
  557. shard_index: Index of the shard to gather.
  558. batch_ms: Batches items for batch_ms milliseconds
  559. before retrieving it.
  560. Increasing batch_ms increases latency but improves throughput.
  561. If this value is 0, then items are returned immediately.
  562. num_async: The max number of requests in flight.
  563. Increasing this improves the amount of pipeline
  564. parallelism in the iterator.
  565. """
  566. if num_async < 1:
  567. raise ValueError("num async must be positive")
  568. if batch_ms < 0:
  569. raise ValueError("batch time must be positive")
  570. a, t = None, None
  571. i = shard_index
  572. for actor_set in self.actor_sets:
  573. if i < len(actor_set.actors):
  574. a = actor_set.actors[i]
  575. t = actor_set.transforms
  576. break
  577. else:
  578. i -= len(actor_set.actors)
  579. if a is None:
  580. raise ValueError("Shard index out of range", shard_index, self.num_shards())
  581. def base_iterator(timeout=None):
  582. queue = collections.deque()
  583. ray.get(a.par_iter_init.remote(t))
  584. for _ in range(num_async):
  585. queue.append(a.par_iter_next_batch.remote(batch_ms))
  586. while True:
  587. try:
  588. batch = ray.get(queue.popleft(), timeout=timeout)
  589. queue.append(a.par_iter_next_batch.remote(batch_ms))
  590. for item in batch:
  591. yield item
  592. # Always yield after each round of gets with timeout.
  593. if timeout is not None:
  594. yield _NextValueNotReady()
  595. except TimeoutError:
  596. yield _NextValueNotReady()
  597. except StopIteration:
  598. break
  599. name = self.name + f".shard[{shard_index}]"
  600. return LocalIterator(base_iterator, SharedMetrics(), name=name)
  601. @Deprecated
  602. class LocalIterator(Generic[T]):
  603. """An iterator over a single shard of data.
  604. It implements similar transformations as ParallelIterator[T], but the
  605. transforms will be applied locally and not remotely in parallel.
  606. This class is **serializable** and can be passed to other remote
  607. tasks and actors. However, it should be read from at most one process at
  608. a time."""
  609. # If a function passed to LocalIterator.for_each() has this method,
  610. # we will call it at the beginning of each data fetch call. This can be
  611. # used to measure the underlying wait latency for measurement purposes.
  612. ON_FETCH_START_HOOK_NAME = "_on_fetch_start"
  613. thread_local = threading.local()
  614. def __init__(
  615. self,
  616. base_iterator: Callable[[], Iterable[T]],
  617. shared_metrics: SharedMetrics,
  618. local_transforms: List[Callable[[Iterable], Any]] = None,
  619. timeout: int = None,
  620. name=None,
  621. ):
  622. """Create a local iterator (this is an internal function).
  623. Args:
  624. base_iterator: A function that produces the base iterator.
  625. This is a function so that we can ensure LocalIterator is
  626. serializable.
  627. shared_metrics: Existing metrics context or a new
  628. context. Should be the same for each chained iterator.
  629. local_transforms: A list of transformation functions to be
  630. applied on top of the base iterator. When iteration begins, we
  631. create the base iterator and apply these functions. This lazy
  632. creation ensures LocalIterator is serializable until you start
  633. iterating over it.
  634. timeout: Optional timeout in seconds for this iterator, after
  635. which _NextValueNotReady will be returned. This avoids
  636. blocking.
  637. name: Optional name for this iterator.
  638. """
  639. assert isinstance(shared_metrics, SharedMetrics)
  640. self.base_iterator = base_iterator
  641. self.built_iterator = None
  642. self.local_transforms = local_transforms or []
  643. self.shared_metrics = shared_metrics
  644. self.timeout = timeout
  645. self.name = name or "unknown"
  646. @staticmethod
  647. def get_metrics() -> MetricsContext:
  648. """Return the current metrics context.
  649. This can only be called within an iterator function."""
  650. if (
  651. not hasattr(LocalIterator.thread_local, "metrics")
  652. or LocalIterator.thread_local.metrics is None
  653. ):
  654. raise ValueError("Cannot access context outside an iterator.")
  655. return LocalIterator.thread_local.metrics
  656. def _build_once(self):
  657. if self.built_iterator is None:
  658. it = iter(self.base_iterator(self.timeout))
  659. for fn in self.local_transforms:
  660. it = fn(it)
  661. self.built_iterator = it
  662. @contextmanager
  663. def _metrics_context(self):
  664. self.thread_local.metrics = self.shared_metrics.get()
  665. yield
  666. def __iter__(self):
  667. self._build_once()
  668. return self.built_iterator
  669. def __next__(self):
  670. self._build_once()
  671. return next(self.built_iterator)
  672. def __str__(self):
  673. return repr(self)
  674. def __repr__(self):
  675. return f"LocalIterator[{self.name}]"
  676. def transform(self, fn: Callable[[Iterable[T]], Iterable[U]]) -> "LocalIterator[U]":
  677. # TODO(ekl) can we automatically handle NextValueNotReady here?
  678. def apply_transform(it):
  679. for item in fn(it):
  680. yield item
  681. return LocalIterator(
  682. self.base_iterator,
  683. self.shared_metrics,
  684. self.local_transforms + [apply_transform],
  685. name=self.name + ".transform()",
  686. )
  687. def for_each(
  688. self, fn: Callable[[T], U], max_concurrency=1, resources=None
  689. ) -> "LocalIterator[U]":
  690. if max_concurrency == 1:
  691. def apply_foreach(it):
  692. for item in it:
  693. if isinstance(item, _NextValueNotReady):
  694. yield item
  695. else:
  696. # Keep retrying the function until it returns a valid
  697. # value. This allows for non-blocking functions.
  698. while True:
  699. with self._metrics_context():
  700. result = fn(item)
  701. yield result
  702. if not isinstance(result, _NextValueNotReady):
  703. break
  704. else:
  705. if resources is None:
  706. resources = {}
  707. def apply_foreach(it):
  708. cur = []
  709. remote = ray.remote(fn).options(**resources)
  710. remote_fn = remote.remote
  711. for item in it:
  712. if isinstance(item, _NextValueNotReady):
  713. yield item
  714. else:
  715. if max_concurrency and len(cur) >= max_concurrency:
  716. finished, cur = ray.wait(cur)
  717. yield from ray.get(finished)
  718. cur.append(remote_fn(item))
  719. while cur:
  720. finished, cur = ray.wait(cur)
  721. yield from ray.get(finished)
  722. if hasattr(fn, LocalIterator.ON_FETCH_START_HOOK_NAME):
  723. unwrapped = apply_foreach
  724. def add_wait_hooks(it):
  725. it = unwrapped(it)
  726. new_item = True
  727. while True:
  728. # Avoids calling on_fetch_start repeatedly if we are
  729. # yielding _NextValueNotReady.
  730. if new_item:
  731. with self._metrics_context():
  732. fn._on_fetch_start()
  733. new_item = False
  734. item = next(it)
  735. if not isinstance(item, _NextValueNotReady):
  736. new_item = True
  737. yield item
  738. apply_foreach = add_wait_hooks
  739. return LocalIterator(
  740. self.base_iterator,
  741. self.shared_metrics,
  742. self.local_transforms + [apply_foreach],
  743. name=self.name + ".for_each()",
  744. )
  745. def filter(self, fn: Callable[[T], bool]) -> "LocalIterator[T]":
  746. def apply_filter(it):
  747. for item in it:
  748. with self._metrics_context():
  749. if isinstance(item, _NextValueNotReady) or fn(item):
  750. yield item
  751. return LocalIterator(
  752. self.base_iterator,
  753. self.shared_metrics,
  754. self.local_transforms + [apply_filter],
  755. name=self.name + ".filter()",
  756. )
  757. def batch(self, n: int) -> "LocalIterator[List[T]]":
  758. def apply_batch(it):
  759. batch = []
  760. for item in it:
  761. if isinstance(item, _NextValueNotReady):
  762. yield item
  763. else:
  764. batch.append(item)
  765. if len(batch) >= n:
  766. yield batch
  767. batch = []
  768. if batch:
  769. yield batch
  770. return LocalIterator(
  771. self.base_iterator,
  772. self.shared_metrics,
  773. self.local_transforms + [apply_batch],
  774. name=self.name + f".batch({n})",
  775. )
  776. def flatten(self) -> "LocalIterator[T[0]]":
  777. def apply_flatten(it):
  778. for item in it:
  779. if isinstance(item, _NextValueNotReady):
  780. yield item
  781. else:
  782. for subitem in item:
  783. yield subitem
  784. return LocalIterator(
  785. self.base_iterator,
  786. self.shared_metrics,
  787. self.local_transforms + [apply_flatten],
  788. name=self.name + ".flatten()",
  789. )
  790. def shuffle(self, shuffle_buffer_size: int, seed: int = None) -> "LocalIterator[T]":
  791. """Shuffle items of this iterator
  792. Args:
  793. shuffle_buffer_size: The algorithm fills a buffer with
  794. shuffle_buffer_size elements and randomly samples elements from
  795. this buffer, replacing the selected elements with new elements.
  796. For perfect shuffling, this argument should be greater than or
  797. equal to the largest iterator size.
  798. seed: Seed to use for
  799. randomness. Default value is None.
  800. Returns:
  801. A new LocalIterator with shuffling applied
  802. """
  803. shuffle_random = random.Random(seed)
  804. def apply_shuffle(it):
  805. buffer = []
  806. for item in it:
  807. if isinstance(item, _NextValueNotReady):
  808. yield item
  809. else:
  810. buffer.append(item)
  811. if len(buffer) >= shuffle_buffer_size:
  812. yield buffer.pop(shuffle_random.randint(0, len(buffer) - 1))
  813. while len(buffer) > 0:
  814. yield buffer.pop(shuffle_random.randint(0, len(buffer) - 1))
  815. return LocalIterator(
  816. self.base_iterator,
  817. self.shared_metrics,
  818. self.local_transforms + [apply_shuffle],
  819. name=self.name
  820. + ".shuffle(shuffle_buffer_size={}, seed={})".format(
  821. shuffle_buffer_size, str(seed) if seed is not None else "None"
  822. ),
  823. )
  824. def combine(self, fn: Callable[[T], List[U]]) -> "LocalIterator[U]":
  825. it = self.for_each(fn).flatten()
  826. it.name = self.name + ".combine()"
  827. return it
  828. def zip_with_source_actor(self):
  829. def zip_with_source(item):
  830. metrics = LocalIterator.get_metrics()
  831. if metrics.current_actor is None:
  832. raise ValueError("Could not identify source actor of item")
  833. return metrics.current_actor, item
  834. it = self.for_each(zip_with_source)
  835. it.name = self.name + ".zip_with_source_actor()"
  836. return it
  837. def take(self, n: int) -> List[T]:
  838. """Return up to the first n items from this iterator."""
  839. out = []
  840. for item in self:
  841. out.append(item)
  842. if len(out) >= n:
  843. break
  844. return out
  845. def show(self, n: int = 20):
  846. """Print up to the first n items from this iterator."""
  847. i = 0
  848. for item in self:
  849. print(item)
  850. i += 1
  851. if i >= n:
  852. break
  853. def duplicate(self, n) -> List["LocalIterator[T]"]:
  854. """Copy this iterator `n` times, duplicating the data.
  855. The child iterators will be prioritized by how much of the parent
  856. stream they have consumed. That is, we will not allow children to fall
  857. behind, since that can cause infinite memory buildup in this operator.
  858. Returns:
  859. List[LocalIterator[T]]: child iterators that each have a copy
  860. of the data of this iterator.
  861. """
  862. if n < 2:
  863. raise ValueError("Number of copies must be >= 2")
  864. queues = []
  865. for _ in range(n):
  866. queues.append(collections.deque())
  867. def fill_next(timeout):
  868. self.timeout = timeout
  869. item = next(self)
  870. for q in queues:
  871. q.append(item)
  872. def make_next(i):
  873. def gen(timeout):
  874. while True:
  875. my_len = len(queues[i])
  876. max_len = max(len(q) for q in queues)
  877. # Yield to let other iterators that have fallen behind
  878. # process more items.
  879. if my_len < max_len:
  880. yield _NextValueNotReady()
  881. else:
  882. if len(queues[i]) == 0:
  883. try:
  884. fill_next(timeout)
  885. except StopIteration:
  886. return
  887. yield queues[i].popleft()
  888. return gen
  889. iterators = []
  890. for i in range(n):
  891. iterators.append(
  892. LocalIterator(
  893. make_next(i),
  894. self.shared_metrics,
  895. [],
  896. name=self.name + f".duplicate[{i}]",
  897. )
  898. )
  899. return iterators
  900. def union(
  901. self,
  902. *others: "LocalIterator[T]",
  903. deterministic: bool = False,
  904. round_robin_weights: List[float] = None,
  905. ) -> "LocalIterator[T]":
  906. """Return an iterator that is the union of this and the others.
  907. Args:
  908. deterministic: If deterministic=True, we alternate between
  909. reading from one iterator and the others. Otherwise we return
  910. items from iterators as they become ready.
  911. round_robin_weights: List of weights to use for round robin
  912. mode. For example, [2, 1] will cause the iterator to pull twice
  913. as many items from the first iterator as the second.
  914. [2, 1, "*"] will cause as many items to be pulled as possible
  915. from the third iterator without blocking. This overrides the
  916. deterministic flag.
  917. """
  918. for it in others:
  919. if not isinstance(it, LocalIterator):
  920. raise ValueError(f"other must be of type LocalIterator, got {type(it)}")
  921. active = []
  922. parent_iters = [self] + list(others)
  923. shared_metrics = SharedMetrics(parents=[p.shared_metrics for p in parent_iters])
  924. timeout = None if deterministic else 0
  925. if round_robin_weights:
  926. if len(round_robin_weights) != len(parent_iters):
  927. raise ValueError(
  928. "Length of round robin weights must equal number of "
  929. "iterators total."
  930. )
  931. timeouts = [0 if w == "*" else None for w in round_robin_weights]
  932. else:
  933. timeouts = [timeout] * len(parent_iters)
  934. round_robin_weights = [1] * len(parent_iters)
  935. for i, it in enumerate(parent_iters):
  936. active.append(
  937. LocalIterator(
  938. it.base_iterator,
  939. shared_metrics,
  940. it.local_transforms,
  941. timeout=timeouts[i],
  942. )
  943. )
  944. active = list(zip(round_robin_weights, active))
  945. def build_union(timeout=None):
  946. while True:
  947. for weight, it in list(active):
  948. if weight == "*":
  949. max_pull = 100 # TOOD(ekl) how to best bound this?
  950. else:
  951. max_pull = _randomized_int_cast(weight)
  952. try:
  953. for _ in range(max_pull):
  954. item = next(it)
  955. if isinstance(item, _NextValueNotReady):
  956. if timeout is not None:
  957. yield item
  958. break
  959. else:
  960. yield item
  961. except StopIteration:
  962. active.remove((weight, it))
  963. if not active:
  964. break
  965. return LocalIterator(
  966. build_union,
  967. shared_metrics,
  968. [],
  969. name=f"LocalUnion[{self}, {', '.join(map(str, others))}]",
  970. )
  971. @Deprecated
  972. class ParallelIteratorWorker(object):
  973. """Worker actor for a ParallelIterator.
  974. Actors that are passed to iter.from_actors() must subclass this interface.
  975. """
  976. def __init__(self, item_generator: Any, repeat: bool):
  977. """Create an iterator worker.
  978. Subclasses must call this init function.
  979. Args:
  980. item_generator: A Python iterable or lambda function
  981. that produces a generator when called. We allow lambda
  982. functions since the generator itself might not be serializable,
  983. but a lambda that returns it can be.
  984. repeat: Whether to loop over the iterator forever.
  985. """
  986. def make_iterator():
  987. if callable(item_generator):
  988. return item_generator()
  989. else:
  990. return item_generator
  991. if repeat:
  992. def cycle():
  993. while True:
  994. it = iter(make_iterator())
  995. if it is item_generator:
  996. raise ValueError(
  997. "Cannot iterate over {0} multiple times."
  998. + "Please pass in the base iterable or"
  999. + "lambda: {0} instead.".format(item_generator)
  1000. )
  1001. for item in it:
  1002. yield item
  1003. self.item_generator = cycle()
  1004. else:
  1005. self.item_generator = make_iterator()
  1006. self.transforms = []
  1007. self.local_it = None
  1008. self.next_ith_buffer = None
  1009. def par_iter_init(self, transforms):
  1010. """Implements ParallelIterator worker init."""
  1011. it = LocalIterator(lambda timeout: self.item_generator, SharedMetrics())
  1012. for fn in transforms:
  1013. it = fn(it)
  1014. assert it is not None, fn
  1015. self.local_it = iter(it)
  1016. def par_iter_next(self):
  1017. """Implements ParallelIterator worker item fetch."""
  1018. assert self.local_it is not None, "must call par_iter_init()"
  1019. return next(self.local_it)
  1020. def par_iter_next_batch(self, batch_ms: int):
  1021. """Batches par_iter_next."""
  1022. batch = []
  1023. if batch_ms == 0:
  1024. batch.append(self.par_iter_next())
  1025. return batch
  1026. t_end = time.time() + (0.001 * batch_ms)
  1027. while time.time() < t_end:
  1028. try:
  1029. batch.append(self.par_iter_next())
  1030. except StopIteration:
  1031. if len(batch) == 0:
  1032. raise StopIteration
  1033. else:
  1034. pass
  1035. return batch
  1036. def par_iter_slice(self, step: int, start: int):
  1037. """Iterates in increments of step starting from start."""
  1038. assert self.local_it is not None, "must call par_iter_init()"
  1039. if self.next_ith_buffer is None:
  1040. self.next_ith_buffer = collections.defaultdict(list)
  1041. index_buffer = self.next_ith_buffer[start]
  1042. if len(index_buffer) > 0:
  1043. return index_buffer.pop(0)
  1044. else:
  1045. for j in range(step):
  1046. try:
  1047. val = next(self.local_it)
  1048. self.next_ith_buffer[j].append(val)
  1049. except StopIteration:
  1050. pass
  1051. if not self.next_ith_buffer[start]:
  1052. raise StopIteration
  1053. return self.next_ith_buffer[start].pop(0)
  1054. def par_iter_slice_batch(self, step: int, start: int, batch_ms: int):
  1055. """Batches par_iter_slice."""
  1056. batch = []
  1057. if batch_ms == 0:
  1058. batch.append(self.par_iter_slice(step, start))
  1059. return batch
  1060. t_end = time.time() + (0.001 * batch_ms)
  1061. while time.time() < t_end:
  1062. try:
  1063. batch.append(self.par_iter_slice(step, start))
  1064. except StopIteration:
  1065. if len(batch) == 0:
  1066. raise StopIteration
  1067. else:
  1068. pass
  1069. return batch
  1070. def _randomized_int_cast(float_value):
  1071. base = int(float_value)
  1072. remainder = float_value - base
  1073. if random.random() < remainder:
  1074. base += 1
  1075. return base
  1076. class _NextValueNotReady(Exception):
  1077. """Indicates that a local iterator has no value currently available.
  1078. This is used internally to implement the union() of multiple blocking
  1079. local generators."""
  1080. pass
  1081. class _ActorSet(object):
  1082. """Helper class that represents a set of actors and transforms."""
  1083. def __init__(
  1084. self,
  1085. actors: List["ray.actor.ActorHandle"],
  1086. transforms: List[Callable[["LocalIterator"], "LocalIterator"]],
  1087. ):
  1088. self.actors = actors
  1089. self.transforms = transforms
  1090. def init_actors(self):
  1091. ray.get([a.par_iter_init.remote(self.transforms) for a in self.actors])
  1092. def with_transform(self, fn):
  1093. return _ActorSet(self.actors, self.transforms + [fn])