equalize.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. from typing import List, Tuple
  2. from ray.data._internal.execution.interfaces import RefBundle
  3. from ray.data._internal.split import _calculate_blocks_rows, _split_at_indices
  4. from ray.data.block import (
  5. Block,
  6. BlockMetadata,
  7. BlockPartition,
  8. _take_first_non_empty_schema,
  9. )
  10. from ray.types import ObjectRef
  11. def _equalize(
  12. per_split_bundles: List[RefBundle],
  13. owned_by_consumer: bool,
  14. ) -> List[RefBundle]:
  15. """Equalize split ref bundles into equal number of rows.
  16. Args:
  17. per_split_bundles: ref bundles to equalize.
  18. Returns:
  19. the equalized ref bundles.
  20. """
  21. if len(per_split_bundles) == 0:
  22. return per_split_bundles
  23. per_split_blocks_with_metadata = [bundle.blocks for bundle in per_split_bundles]
  24. per_split_num_rows: List[List[int]] = [
  25. _calculate_blocks_rows(split) for split in per_split_blocks_with_metadata
  26. ]
  27. total_rows = sum([sum(blocks_rows) for blocks_rows in per_split_num_rows])
  28. target_split_size = total_rows // len(per_split_blocks_with_metadata)
  29. # phase 1: shave the current splits by dropping blocks (into leftovers)
  30. # and calculate num rows needed to the meet target.
  31. shaved_splits, per_split_needed_rows, leftovers = _shave_all_splits(
  32. per_split_blocks_with_metadata, per_split_num_rows, target_split_size
  33. )
  34. # validate invariants
  35. for shaved_split, split_needed_row in zip(shaved_splits, per_split_needed_rows):
  36. num_shaved_rows = sum([meta.num_rows for _, meta in shaved_split])
  37. assert num_shaved_rows <= target_split_size
  38. assert num_shaved_rows + split_needed_row == target_split_size
  39. # phase 2: based on the num rows needed for each shaved split, split the leftovers
  40. # in the shape that exactly matches the rows needed.
  41. schema = _take_first_non_empty_schema(bundle.schema for bundle in per_split_bundles)
  42. leftover_bundle = RefBundle(leftovers, owns_blocks=owned_by_consumer, schema=schema)
  43. leftover_splits = _split_leftovers(leftover_bundle, per_split_needed_rows)
  44. # phase 3: merge the shaved_splits and leftoever splits and return.
  45. for i, leftover_split in enumerate(leftover_splits):
  46. shaved_splits[i].extend(leftover_split)
  47. # validate invariants.
  48. num_shaved_rows = sum([meta.num_rows for _, meta in shaved_splits[i]])
  49. assert num_shaved_rows == target_split_size
  50. # Compose the result back to RefBundle
  51. equalized_ref_bundles: List[RefBundle] = []
  52. for split in shaved_splits:
  53. equalized_ref_bundles.append(
  54. RefBundle(split, owns_blocks=owned_by_consumer, schema=schema)
  55. )
  56. return equalized_ref_bundles
  57. def _shave_one_split(
  58. split: BlockPartition, num_rows_per_block: List[int], target_size: int
  59. ) -> Tuple[BlockPartition, int, BlockPartition]:
  60. """Shave a block list to the target size.
  61. Args:
  62. split: the block list to shave.
  63. num_rows_per_block: num rows for each block in the list.
  64. target_size: the upper bound target size of the shaved list.
  65. Returns:
  66. A tuple of:
  67. - shaved block list.
  68. - num of rows needed for the block list to meet the target size.
  69. - leftover blocks.
  70. """
  71. # iterates through the blocks from the input list and
  72. shaved = []
  73. leftovers = []
  74. shaved_rows = 0
  75. for block_with_meta, block_rows in zip(split, num_rows_per_block):
  76. if block_rows + shaved_rows <= target_size:
  77. shaved.append(block_with_meta)
  78. shaved_rows += block_rows
  79. else:
  80. leftovers.append(block_with_meta)
  81. num_rows_needed = target_size - shaved_rows
  82. return shaved, num_rows_needed, leftovers
  83. def _shave_all_splits(
  84. input_splits: List[BlockPartition],
  85. per_split_num_rows: List[List[int]],
  86. target_size: int,
  87. ) -> Tuple[List[BlockPartition], List[int], BlockPartition]:
  88. """Shave all block list to the target size.
  89. Args:
  90. input_splits: all block list to shave.
  91. input_splits: num rows (per block) for each block list.
  92. target_size: the upper bound target size of the shaved lists.
  93. Returns:
  94. A tuple of:
  95. - all shaved block list.
  96. - num of rows needed for the block list to meet the target size.
  97. - leftover blocks.
  98. """
  99. shaved_splits = []
  100. per_split_needed_rows = []
  101. leftovers = []
  102. for split, num_rows_per_block in zip(input_splits, per_split_num_rows):
  103. shaved, num_rows_needed, _leftovers = _shave_one_split(
  104. split, num_rows_per_block, target_size
  105. )
  106. shaved_splits.append(shaved)
  107. per_split_needed_rows.append(num_rows_needed)
  108. leftovers.extend(_leftovers)
  109. return shaved_splits, per_split_needed_rows, leftovers
  110. def _split_leftovers(
  111. leftovers: RefBundle, per_split_needed_rows: List[int]
  112. ) -> List[BlockPartition]:
  113. """Split leftover blocks by the num of rows needed."""
  114. num_splits = len(per_split_needed_rows)
  115. split_indices = []
  116. prev = 0
  117. for i, num_rows_needed in enumerate(per_split_needed_rows):
  118. split_indices.append(prev + num_rows_needed)
  119. prev = split_indices[i]
  120. split_result: Tuple[
  121. List[List[ObjectRef[Block]]], List[List[BlockMetadata]]
  122. ] = _split_at_indices(
  123. leftovers.blocks,
  124. split_indices,
  125. leftovers.owns_blocks,
  126. )
  127. return [list(zip(block_refs, meta)) for block_refs, meta in zip(*split_result)][
  128. :num_splits
  129. ]