| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150 |
- from typing import List, Tuple
- from ray.data._internal.execution.interfaces import RefBundle
- from ray.data._internal.split import _calculate_blocks_rows, _split_at_indices
- from ray.data.block import (
- Block,
- BlockMetadata,
- BlockPartition,
- _take_first_non_empty_schema,
- )
- from ray.types import ObjectRef
- def _equalize(
- per_split_bundles: List[RefBundle],
- owned_by_consumer: bool,
- ) -> List[RefBundle]:
- """Equalize split ref bundles into equal number of rows.
- Args:
- per_split_bundles: ref bundles to equalize.
- Returns:
- the equalized ref bundles.
- """
- if len(per_split_bundles) == 0:
- return per_split_bundles
- per_split_blocks_with_metadata = [bundle.blocks for bundle in per_split_bundles]
- per_split_num_rows: List[List[int]] = [
- _calculate_blocks_rows(split) for split in per_split_blocks_with_metadata
- ]
- total_rows = sum([sum(blocks_rows) for blocks_rows in per_split_num_rows])
- target_split_size = total_rows // len(per_split_blocks_with_metadata)
- # phase 1: shave the current splits by dropping blocks (into leftovers)
- # and calculate num rows needed to the meet target.
- shaved_splits, per_split_needed_rows, leftovers = _shave_all_splits(
- per_split_blocks_with_metadata, per_split_num_rows, target_split_size
- )
- # validate invariants
- for shaved_split, split_needed_row in zip(shaved_splits, per_split_needed_rows):
- num_shaved_rows = sum([meta.num_rows for _, meta in shaved_split])
- assert num_shaved_rows <= target_split_size
- assert num_shaved_rows + split_needed_row == target_split_size
- # phase 2: based on the num rows needed for each shaved split, split the leftovers
- # in the shape that exactly matches the rows needed.
- schema = _take_first_non_empty_schema(bundle.schema for bundle in per_split_bundles)
- leftover_bundle = RefBundle(leftovers, owns_blocks=owned_by_consumer, schema=schema)
- leftover_splits = _split_leftovers(leftover_bundle, per_split_needed_rows)
- # phase 3: merge the shaved_splits and leftoever splits and return.
- for i, leftover_split in enumerate(leftover_splits):
- shaved_splits[i].extend(leftover_split)
- # validate invariants.
- num_shaved_rows = sum([meta.num_rows for _, meta in shaved_splits[i]])
- assert num_shaved_rows == target_split_size
- # Compose the result back to RefBundle
- equalized_ref_bundles: List[RefBundle] = []
- for split in shaved_splits:
- equalized_ref_bundles.append(
- RefBundle(split, owns_blocks=owned_by_consumer, schema=schema)
- )
- return equalized_ref_bundles
- def _shave_one_split(
- split: BlockPartition, num_rows_per_block: List[int], target_size: int
- ) -> Tuple[BlockPartition, int, BlockPartition]:
- """Shave a block list to the target size.
- Args:
- split: the block list to shave.
- num_rows_per_block: num rows for each block in the list.
- target_size: the upper bound target size of the shaved list.
- Returns:
- A tuple of:
- - shaved block list.
- - num of rows needed for the block list to meet the target size.
- - leftover blocks.
- """
- # iterates through the blocks from the input list and
- shaved = []
- leftovers = []
- shaved_rows = 0
- for block_with_meta, block_rows in zip(split, num_rows_per_block):
- if block_rows + shaved_rows <= target_size:
- shaved.append(block_with_meta)
- shaved_rows += block_rows
- else:
- leftovers.append(block_with_meta)
- num_rows_needed = target_size - shaved_rows
- return shaved, num_rows_needed, leftovers
- def _shave_all_splits(
- input_splits: List[BlockPartition],
- per_split_num_rows: List[List[int]],
- target_size: int,
- ) -> Tuple[List[BlockPartition], List[int], BlockPartition]:
- """Shave all block list to the target size.
- Args:
- input_splits: all block list to shave.
- input_splits: num rows (per block) for each block list.
- target_size: the upper bound target size of the shaved lists.
- Returns:
- A tuple of:
- - all shaved block list.
- - num of rows needed for the block list to meet the target size.
- - leftover blocks.
- """
- shaved_splits = []
- per_split_needed_rows = []
- leftovers = []
- for split, num_rows_per_block in zip(input_splits, per_split_num_rows):
- shaved, num_rows_needed, _leftovers = _shave_one_split(
- split, num_rows_per_block, target_size
- )
- shaved_splits.append(shaved)
- per_split_needed_rows.append(num_rows_needed)
- leftovers.extend(_leftovers)
- return shaved_splits, per_split_needed_rows, leftovers
- def _split_leftovers(
- leftovers: RefBundle, per_split_needed_rows: List[int]
- ) -> List[BlockPartition]:
- """Split leftover blocks by the num of rows needed."""
- num_splits = len(per_split_needed_rows)
- split_indices = []
- prev = 0
- for i, num_rows_needed in enumerate(per_split_needed_rows):
- split_indices.append(prev + num_rows_needed)
- prev = split_indices[i]
- split_result: Tuple[
- List[List[ObjectRef[Block]]], List[List[BlockMetadata]]
- ] = _split_at_indices(
- leftovers.blocks,
- split_indices,
- leftovers.owns_blocks,
- )
- return [list(zip(block_refs, meta)) for block_refs, meta in zip(*split_result)][
- :num_splits
- ]
|