# util.py
from typing import Iterable

from ray.data.block import Block
  3. def _iter_sliced_blocks(
  4. blocks: Iterable[Block], per_task_row_limit: int
  5. ) -> Iterable[Block]:
  6. """Iterate over blocks, accumulating rows up to the per-task row limit."""
  7. rows_read = 0
  8. for block in blocks:
  9. if rows_read >= per_task_row_limit:
  10. break
  11. from ray.data.block import BlockAccessor
  12. accessor = BlockAccessor.for_block(block)
  13. block_rows = accessor.num_rows()
  14. if rows_read + block_rows <= per_task_row_limit:
  15. yield block
  16. rows_read += block_rows
  17. else:
  18. # Slice the block to meet the limit exactly
  19. remaining_rows = per_task_row_limit - rows_read
  20. sliced_block = accessor.slice(0, remaining_rows, copy=True)
  21. yield sliced_block
  22. break