streaming_repartition.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. from collections import deque
  2. from typing import Deque, List, Tuple
  3. from ray.data._internal.execution.interfaces import RefBundle
  4. from ray.data._internal.execution.operators.map_operator import BaseRefBundler
  5. """Streaming repartition builds fixed-size outputs from a stream of inputs.
  6. We construct batches here to produce exactly sized outputs from arbitrary [start, end) slices across input blocks.
  7. The task builder submits a map task only after the total number of rows accumulated across pending blocks reaches
  8. the target number of rows (except during the final flush, which may emit a smaller tail block). This allows us to create
  9. target-sized batches without materializing entire large blocks on the driver.
  10. Detailed Implementation:
  11. 1. When a new bundle arrives, buffer it in the pending list.
  12. 2. Whenever the total number of rows in the pending bundles reaches the target row count, try to build a ready bundle.
  13. 3. Determine the slice needed from the final bundle so the ready bundle holds an exact multiple of the target rows,
  14. and add the remaining bundle to the pending bundles for the next iteration.
  15. 4. Submit that ready bundle to a remote map task; the task slices each block according to the slice metadata stored
  16. in the RefBundle (the bundle now contains n × target rows for n ≥ 1).
  17. 5. We configured the `OutputBlockSizeOption.target_num_rows_per_block` to the target number of rows per block in
  18. plan_streaming_repartition_op so the output buffer further splits the n × target rows into n blocks of exactly
  19. the target size.
  20. 6. Once upstream input is exhausted, merge any leftover pending bundles into one final ready bundle and run steps 4–5 on that tail.
  21. 7. The resulting blocks have lengths `[target, …, target, (total_rows % target)]`; ordering isn’t guaranteed, but the
  22. remainder block should appear near the end.
  23. """
  24. class StreamingRepartitionRefBundler(BaseRefBundler):
  25. """Incrementally builds task inputs to produce multiples of target-sized outputs."""
  26. def __init__(self, target_num_rows_per_block: int):
  27. assert (
  28. target_num_rows_per_block > 0
  29. ), "target_num_rows_per_block must be positive for streaming repartition."
  30. self._target_num_rows = target_num_rows_per_block
  31. self._pending_bundles: Deque[RefBundle] = deque()
  32. self._ready_bundles: Deque[RefBundle] = deque()
  33. self._consumed_input_bundles: List[RefBundle] = []
  34. self._total_pending_rows = 0
  35. def _try_build_ready_bundle(self, flush_remaining: bool = False):
  36. if self._total_pending_rows >= self._target_num_rows:
  37. rows_needed_from_last_bundle = (
  38. self._pending_bundles[-1].num_rows()
  39. - self._total_pending_rows % self._target_num_rows
  40. )
  41. assert rows_needed_from_last_bundle >= 0 # This will never be negative
  42. pending_bundles = list(self._pending_bundles)
  43. remaining_bundle = None
  44. if (
  45. rows_needed_from_last_bundle > 0
  46. and rows_needed_from_last_bundle < pending_bundles[-1].num_rows()
  47. ):
  48. last_bundle = pending_bundles.pop()
  49. sliced_bundle, remaining_bundle = last_bundle.slice(
  50. rows_needed_from_last_bundle
  51. )
  52. pending_bundles.append(sliced_bundle)
  53. self._ready_bundles.append(RefBundle.merge_ref_bundles(pending_bundles))
  54. self._pending_bundles.clear()
  55. self._total_pending_rows = 0
  56. if remaining_bundle and remaining_bundle.num_rows() > 0:
  57. self._pending_bundles.append(remaining_bundle)
  58. self._total_pending_rows += remaining_bundle.num_rows()
  59. if flush_remaining and len(self._pending_bundles) > 0:
  60. self._ready_bundles.append(
  61. RefBundle.merge_ref_bundles(self._pending_bundles)
  62. )
  63. self._pending_bundles.clear()
  64. self._total_pending_rows = 0
  65. def add_bundle(self, ref_bundle: RefBundle):
  66. self._total_pending_rows += ref_bundle.num_rows()
  67. self._pending_bundles.append(ref_bundle)
  68. self._try_build_ready_bundle()
  69. self._consumed_input_bundles.append(ref_bundle)
  70. def has_bundle(self) -> bool:
  71. return len(self._ready_bundles) > 0
  72. def get_next_bundle(
  73. self,
  74. ) -> Tuple[List[RefBundle], RefBundle]:
  75. consumed_input_bundles = self._consumed_input_bundles
  76. self._consumed_input_bundles = []
  77. return consumed_input_bundles, self._ready_bundles.popleft()
  78. def done_adding_bundles(self):
  79. if len(self._pending_bundles) > 0:
  80. self._try_build_ready_bundle(flush_remaining=True)
  81. def num_blocks(self):
  82. return sum(len(bundle) for bundle in self._pending_bundles) + sum(
  83. len(bundle) for bundle in self._ready_bundles
  84. )
  85. def size_bytes(self) -> int:
  86. return sum(bundle.size_bytes() for bundle in self._pending_bundles) + sum(
  87. bundle.size_bytes() for bundle in self._ready_bundles
  88. )