import itertools
import logging
from typing import Iterable, List, Optional, Tuple, Union

import ray
from ray.data._internal.memory_tracing import trace_deallocation
from ray.data._internal.remote_fn import cached_remote_fn
from ray.data.block import (
    Block,
    BlockAccessor,
    BlockExecStats,
    BlockMetadata,
    BlockPartition,
)
from ray.types import ObjectRef

logger = logging.getLogger(__name__)


def _calculate_blocks_rows(
    blocks_with_metadata: BlockPartition,
) -> List[int]:
    """Calculate the number of rows for a list of blocks with metadata.

    Metadata entries whose ``num_rows`` is ``None`` are filled in place with
    the row count fetched from the block.

    Args:
        blocks_with_metadata: ``(block_ref, metadata)`` pairs.

    Returns:
        The number of rows of each block, in input order.
    """
    get_num_rows = cached_remote_fn(_get_num_rows)
    metadatas = []
    missing_metadatas = []
    missing_refs = []
    for block, metadata in blocks_with_metadata:
        metadatas.append(metadata)
        if metadata.num_rows is None:
            # Need to fetch number of rows; launch the task now, wait later.
            missing_metadatas.append(metadata)
            missing_refs.append(get_num_rows.remote(block))
    if missing_refs:
        # A single ray.get lets all row-count tasks run in parallel instead of
        # paying one blocking round trip per block.
        for metadata, num_rows in zip(missing_metadatas, ray.get(missing_refs)):
            metadata.num_rows = num_rows
    return [metadata.num_rows for metadata in metadatas]


def _generate_valid_indices(
    num_rows_per_block: List[int],
    split_indices: List[int],
) -> List[int]:
    """Generate valid split indices by applying min(index, total_num_rows) to every index."""
    total_rows = sum(num_rows_per_block)
    return [min(index, total_rows) for index in split_indices]


def _generate_per_block_split_indices(
    num_rows_per_block: List[int],
    split_indices: List[int],
) -> List[List[int]]:
    """Given num rows per block and valid split indices, generate per block split indices.

    Args:
        num_rows_per_block: num of rows per block.
        split_indices: The (global) indices at which to split the blocks. They
            must be sorted in non-decreasing order and not exceed the total
            row count (see ``_generate_valid_indices``): the scan below only
            ever moves forward through the blocks.

    Returns:
        Per block split indices indicating each input block's split point(s),
        expressed as offsets local to that block; one list per input block.
    """
    # For each split index, we check whether it falls into the current input
    # block. If it does, we record it (as a block-local offset) in the current
    # block's split indices. Otherwise, we move on to the next block.
    per_block_split_indices = []
    current_input_block_id = 0
    current_block_split_indices = []
    current_block_global_offset = 0
    current_index_id = 0
    while current_index_id < len(split_indices):
        split_index = split_indices[current_index_id]
        current_block_row = num_rows_per_block[current_input_block_id]
        if split_index - current_block_global_offset <= current_block_row:
            current_block_split_indices.append(
                split_index - current_block_global_offset
            )
            current_index_id += 1
            continue
        per_block_split_indices.append(current_block_split_indices)
        current_block_split_indices = []
        current_block_global_offset += num_rows_per_block[current_input_block_id]
        current_input_block_id += 1

    # We may have consumed all indices while blocks remain, and the last
    # current_block_split_indices may not have been added yet.
    while len(per_block_split_indices) < len(num_rows_per_block):
        per_block_split_indices.append(current_block_split_indices)
        current_block_split_indices = []
    return per_block_split_indices


def _split_single_block(
    block_id: int,
    block: Block,
    meta: BlockMetadata,
    split_indices: List[int],
) -> Tuple[Union[Tuple[int, List[BlockMetadata]], Block], ...]:
    """Split the provided block at the given indices.

    Args:
        block_id: the id of this block in the block list.
        block: block to be split.
        meta: metadata of the block; ``meta.num_rows`` is expected to be valid.
        split_indices: the block-local indices where the block should be split.
            This list is not modified.

    Returns:
        block_id, the split blocks' metadata, and the split blocks, in the
        following form. Blocks are returned this way so that the owner of the
        blocks can be the caller (driver) instead of the worker itself.

        Tuple(block_id, split_blocks_meta), block0, block1 ...
    """
    split_meta = []
    split_blocks = []
    block_accessor = BlockAccessor.for_block(block)
    prev_index = 0
    # Add the block end as a final boundary so the tail slice is produced by
    # the same loop. Build a new list instead of appending to split_indices so
    # the caller's list is left untouched.
    for index in [*split_indices, meta.num_rows]:
        logger.debug("slicing block %s:%s", prev_index, index)
        stats = BlockExecStats.builder()
        split_block = block_accessor.slice(prev_index, index)
        accessor = BlockAccessor.for_block(split_block)
        _meta = BlockMetadata(
            num_rows=accessor.num_rows(),
            size_bytes=accessor.size_bytes(),
            input_files=meta.input_files,
            exec_stats=stats.build(),
        )
        split_meta.append(_meta)
        split_blocks.append(split_block)
        prev_index = index
    results = [(block_id, split_meta)]
    results.extend(split_blocks)
    return tuple(results)


def _drop_empty_block_split(block_split_indices: List[int], num_rows: int) -> List[int]:
    """Drop split indices that would create an empty block split.

    This can happen with duplicated indices, or with an index equal to 0 (start
    of the block) or num_rows (end of the block). Only adjacent duplicates are
    detected, so the indices are expected to be sorted.
    """
    prev_index = -1
    optimized_indices = []
    for index in block_split_indices:
        if index == 0 or index == num_rows:
            continue
        if index == prev_index:
            continue
        optimized_indices.append(index)
        prev_index = index
    return optimized_indices


def _split_all_blocks(
    blocks_with_metadata: List[Tuple[ObjectRef[Block], BlockMetadata]],
    per_block_split_indices: List[List[int]],
    owned_by_consumer: bool,
) -> Iterable[Tuple[ObjectRef[Block], BlockMetadata]]:
    """Split all the input blocks based on the split indices.

    Args:
        blocks_with_metadata: Input block refs with their metadata; each
            ``meta.num_rows`` is expected to be populated.
        per_block_split_indices: Block-local split offsets, one list per input
            block (as produced by ``_generate_per_block_split_indices``).
        owned_by_consumer: Whether the input blocks are owned by the consumer;
            controls how split input blocks are passed to trace_deallocation.

    Returns:
        An iterator over ``(block_ref, metadata)`` pairs of the resulting
        blocks, in input order.
    """
    split_single_block = cached_remote_fn(_split_single_block)

    all_blocks_split_results: List[
        Optional[Iterable[Tuple[ObjectRef[Block], BlockMetadata]]]
    ] = [None] * len(blocks_with_metadata)

    per_block_split_metadata_futures = []
    per_block_split_block_refs = []

    # Track the input blocks that were split, for the deallocation below.
    blocks_split = []

    for block_id, block_split_indices in enumerate(per_block_split_indices):
        block_ref, meta = blocks_with_metadata[block_id]
        block_row = meta.num_rows
        block_split_indices = _drop_empty_block_split(block_split_indices, block_row)
        if not block_split_indices:
            # Optimization: no split is needed, pass the block through as is.
            all_blocks_split_results[block_id] = [(block_ref, meta)]
        else:
            # Otherwise split remotely. The task returns the (block_id, metadata)
            # tuple followed by len(indices) + 1 slices, hence 2 + len(indices).
            object_refs = split_single_block.options(
                scheduling_strategy="SPREAD", num_returns=2 + len(block_split_indices)
            ).remote(
                block_id,
                block_ref,
                meta,
                block_split_indices,
            )
            per_block_split_metadata_futures.append(object_refs[0])
            per_block_split_block_refs.append(object_refs[1:])

            blocks_split.append(block_ref)

    if per_block_split_metadata_futures:
        # Only fetch the metadata; the split block refs are not fetched here.
        per_block_split_metadata = ray.get(per_block_split_metadata_futures)
        for (block_id, meta), block_refs in zip(
            per_block_split_metadata, per_block_split_block_refs
        ):
            assert len(meta) == len(block_refs)
            all_blocks_split_results[block_id] = zip(block_refs, meta)

    # The split outputs are new copies, so the input blocks that were split can
    # be cleared if they are owned by the consumer (consumer-owned blocks will
    # only be consumed by the owner). Otherwise trace_deallocation is called
    # with free=False.
    if owned_by_consumer:
        for b in blocks_split:
            trace_deallocation(b, "split._split_all_blocks")
    else:
        for b in blocks_split:
            trace_deallocation(b, "split._split_all_blocks", free=False)

    return itertools.chain.from_iterable(all_blocks_split_results)


def _generate_global_split_results(
    all_blocks_split_results: Iterable[Tuple[ObjectRef[Block], BlockMetadata]],
    global_split_sizes: List[int],
) -> Tuple[List[List[ObjectRef[Block]]], List[List[BlockMetadata]]]:
    """Reassemble per block's split result into final split result.

    Args:
        all_blocks_split_results: ``(block_ref, metadata)`` pairs in order,
            where no block straddles an output split boundary.
        global_split_sizes: Expected row count of each output split.

    Returns:
        The block refs and the metadata of each output split.
    """
    # Accept any iterable, not only iterators, since next() is used below.
    results_iter = iter(all_blocks_split_results)
    result_blocks = []
    result_metas = []
    current_blocks = []
    current_meta = []
    current_split_size = 0
    current_split_id = 0

    while current_split_id < len(global_split_sizes):
        if current_split_size >= global_split_sizes[current_split_id]:
            assert current_split_size == global_split_sizes[current_split_id]
            result_blocks.append(current_blocks)
            result_metas.append(current_meta)

            current_blocks = []
            current_meta = []
            current_split_size = 0
            current_split_id += 1
        else:
            (block_ref, meta) = next(results_iter)
            current_blocks.append(block_ref)
            current_meta.append(meta)
            current_split_size += meta.num_rows

    return result_blocks, result_metas


def _split_at_indices(
    blocks_with_metadata: List[Tuple[ObjectRef[Block], BlockMetadata]],
    indices: List[int],
    owned_by_consumer: bool = True,
    block_rows: Optional[List[int]] = None,
) -> Tuple[List[List[ObjectRef[Block]]], List[List[BlockMetadata]]]:
    """Split blocks at the provided indices.

    Args:
        blocks_with_metadata: Block futures to split, including the associated
            metadata.
        indices: The (global) indices at which to split the blocks.
        owned_by_consumer: Whether the provided blocks are owned by the consumer.
        block_rows: The number of rows for each block, in case it has already
            been computed.

    Returns:
        The block split futures and their metadata. If an index split is empty,
        the corresponding block split will be empty.
    """

    # We implement the split in 3 phases.
    # phase 1: calculate the per block split indices.
    blocks_with_metadata = list(blocks_with_metadata)
    if not blocks_with_metadata:
        # Build independent lists: ``[[]] * n`` would alias one list n times,
        # so mutating one returned split would change all of them.
        num_splits = len(indices) + 1
        return (
            [[] for _ in range(num_splits)],
            [[] for _ in range(num_splits)],
        )
    if block_rows is None:
        block_rows = _calculate_blocks_rows(blocks_with_metadata)
    valid_indices = _generate_valid_indices(block_rows, indices)
    per_block_split_indices: List[List[int]] = _generate_per_block_split_indices(
        block_rows, valid_indices
    )

    # phase 2: split each block based on the indices from previous step.
    all_blocks_split_results: Iterable[
        Tuple[ObjectRef[Block], BlockMetadata]
    ] = _split_all_blocks(
        blocks_with_metadata, per_block_split_indices, owned_by_consumer
    )

    # phase 3: generate the final split.

    # first calculate the size for each split from consecutive boundaries.
    boundaries = [0, *valid_indices, sum(block_rows)]
    split_sizes = [end - start for start, end in zip(boundaries, boundaries[1:])]

    return _generate_global_split_results(all_blocks_split_results, split_sizes)


def _get_num_rows(block: Block) -> int:
    """Get the number of rows contained in the provided block."""
    return BlockAccessor.for_block(block).num_rows()