| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103 |
- from collections import deque
- from typing import Deque, List, Tuple
- from ray.data._internal.execution.interfaces import RefBundle
- from ray.data._internal.execution.operators.map_operator import BaseRefBundler
- """Streaming repartition builds fixed-size outputs from a stream of inputs.
- We construct batches here to produce exactly sized outputs from arbitrary [start, end) slices across input blocks.
- The task builder submits a map task only after the total number of rows accumulated across pending blocks reaches
- target num rows (except during the final flush, which may emit a smaller tail block). This allows us to create
- target-sized batches without materializing entire large blocks on the driver.
- Detailed Implementation:
- 1. When a new bundle arrives, buffer it in the pending list.
- 2. Whenever the total number of rows in the pending bundles reaches the target row count, try to build a ready bundle.
- 3. Determine the slice needed from the final bundle so the ready bundle holds an exact multiple of the target rows,
- and add the remaining bundle to the pending bundles for the next iteration.
- 4. Submit that ready bundle to a remote map task; the task slices each block according to the slice metadata stored
- in the RefBundle (the bundle now contains n × target rows for n ≥ 1).
- 5. We configured the `OutputBlockSizeOption.target_num_rows_per_block` to the target number of rows per block in
- plan_streaming_repartition_op so the output buffer further splits the n × target rows into n blocks of exactly
- the target size.
- 6. Once upstream input is exhausted, flush any leftover pending bundles and repeat steps 1‑5 for the tail.
- 7. The resulting blocks have lengths `[target, …, target, (total_rows % target)]`; ordering isn’t guaranteed, but the
- remainder block should appear near the end.
- """
- class StreamingRepartitionRefBundler(BaseRefBundler):
- """Incrementally builds task inputs to produce multiples of target-sized outputs."""
- def __init__(self, target_num_rows_per_block: int):
- assert (
- target_num_rows_per_block > 0
- ), "target_num_rows_per_block must be positive for streaming repartition."
- self._target_num_rows = target_num_rows_per_block
- self._pending_bundles: Deque[RefBundle] = deque()
- self._ready_bundles: Deque[RefBundle] = deque()
- self._consumed_input_bundles: List[RefBundle] = []
- self._total_pending_rows = 0
- def _try_build_ready_bundle(self, flush_remaining: bool = False):
- if self._total_pending_rows >= self._target_num_rows:
- rows_needed_from_last_bundle = (
- self._pending_bundles[-1].num_rows()
- - self._total_pending_rows % self._target_num_rows
- )
- assert rows_needed_from_last_bundle >= 0 # This will never be negative
- pending_bundles = list(self._pending_bundles)
- remaining_bundle = None
- if (
- rows_needed_from_last_bundle > 0
- and rows_needed_from_last_bundle < pending_bundles[-1].num_rows()
- ):
- last_bundle = pending_bundles.pop()
- sliced_bundle, remaining_bundle = last_bundle.slice(
- rows_needed_from_last_bundle
- )
- pending_bundles.append(sliced_bundle)
- self._ready_bundles.append(RefBundle.merge_ref_bundles(pending_bundles))
- self._pending_bundles.clear()
- self._total_pending_rows = 0
- if remaining_bundle and remaining_bundle.num_rows() > 0:
- self._pending_bundles.append(remaining_bundle)
- self._total_pending_rows += remaining_bundle.num_rows()
- if flush_remaining and len(self._pending_bundles) > 0:
- self._ready_bundles.append(
- RefBundle.merge_ref_bundles(self._pending_bundles)
- )
- self._pending_bundles.clear()
- self._total_pending_rows = 0
- def add_bundle(self, ref_bundle: RefBundle):
- self._total_pending_rows += ref_bundle.num_rows()
- self._pending_bundles.append(ref_bundle)
- self._try_build_ready_bundle()
- self._consumed_input_bundles.append(ref_bundle)
- def has_bundle(self) -> bool:
- return len(self._ready_bundles) > 0
- def get_next_bundle(
- self,
- ) -> Tuple[List[RefBundle], RefBundle]:
- consumed_input_bundles = self._consumed_input_bundles
- self._consumed_input_bundles = []
- return consumed_input_bundles, self._ready_bundles.popleft()
- def done_adding_bundles(self):
- if len(self._pending_bundles) > 0:
- self._try_build_ready_bundle(flush_remaining=True)
- def num_blocks(self):
- return sum(len(bundle) for bundle in self._pending_bundles) + sum(
- len(bundle) for bundle in self._ready_bundles
- )
- def size_bytes(self) -> int:
- return sum(bundle.size_bytes() for bundle in self._pending_bundles) + sum(
- bundle.size_bytes() for bundle in self._ready_bundles
- )
|