| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296 |
- import itertools
- import logging
- from typing import Iterable, List, Tuple, Union
- import ray
- from ray.data._internal.memory_tracing import trace_deallocation
- from ray.data._internal.remote_fn import cached_remote_fn
- from ray.data.block import (
- Block,
- BlockAccessor,
- BlockExecStats,
- BlockMetadata,
- BlockPartition,
- )
- from ray.types import ObjectRef
- logger = logging.getLogger(__name__)
def _calculate_blocks_rows(
    blocks_with_metadata: BlockPartition,
) -> List[int]:
    """Calculate the number of rows for a list of blocks with metadata.

    Row counts already recorded in the metadata are reused. For every block whose
    ``metadata.num_rows`` is ``None`` a remote ``_get_num_rows`` task is submitted;
    all submitted tasks are then resolved with a single ``ray.get`` call and the
    results are written back into the corresponding metadata objects.

    Args:
        blocks_with_metadata: ``(block, metadata)`` pairs to inspect.

    Returns:
        The number of rows of each block, in input order.
    """
    get_num_rows = cached_remote_fn(_get_num_rows)
    block_rows = []
    # Entries are (position in block_rows, metadata to back-fill, pending future).
    pending = []
    for block, metadata in blocks_with_metadata:
        if metadata.num_rows is None:
            # Only submit here. Calling ray.get() inside the loop would wait for
            # each task before submitting the next one.
            pending.append((len(block_rows), metadata, get_num_rows.remote(block)))
        # Placeholder (None) for pending entries; filled in below.
        block_rows.append(metadata.num_rows)
    if pending:
        fetched = ray.get([future for _, _, future in pending])
        for (position, metadata, _), num_rows in zip(pending, fetched):
            metadata.num_rows = num_rows
            block_rows[position] = num_rows
    return block_rows
def _generate_valid_indices(
    num_rows_per_block: List[int],
    split_indices: List[int],
) -> List[int]:
    """Clamp every split index to the total number of rows.

    Args:
        num_rows_per_block: num of rows per block.
        split_indices: The (global) indices at which to split the blocks.

    Returns:
        A new list holding ``min(index, total_num_rows)`` for each input index,
        in the same order.
    """
    row_limit = sum(num_rows_per_block)
    clamped = []
    for requested in split_indices:
        clamped.append(min(requested, row_limit))
    return clamped
def _generate_per_block_split_indices(
    num_rows_per_block: List[int],
    split_indices: List[int],
) -> List[List[int]]:
    """Translate global split indices into block-local split indices.

    Args:
        num_rows_per_block: num of rows per block.
        split_indices: The (global) indices at which to split the blocks. They
            are expected to be valid, i.e. not larger than the total row count.

    Returns:
        One list per input block, holding the split point(s) of that block
        expressed relative to the block's first row.
    """
    per_block = []
    # Local split points collected so far for the block at `block_pos`.
    pending = []
    block_pos = 0
    # Number of rows contained in the blocks before `block_pos`.
    rows_before = 0

    for global_index in split_indices:
        # Walk forward until we reach the block this index falls into, closing
        # the split list of every block we leave behind.
        while global_index - rows_before > num_rows_per_block[block_pos]:
            per_block.append(pending)
            pending = []
            rows_before += num_rows_per_block[block_pos]
            block_pos += 1
        pending.append(global_index - rows_before)

    # All indices are consumed, but the current block's list has not been added
    # yet and there may be further blocks that receive no split point at all.
    missing = len(num_rows_per_block) - len(per_block)
    if missing > 0:
        per_block.append(pending)
        per_block.extend([] for _ in range(missing - 1))
    return per_block
def _split_single_block(
    block_id: int,
    block: Block,
    meta: BlockMetadata,
    split_indices: List[int],
) -> Tuple[Union[Tuple[int, List[BlockMetadata]], Block], ...]:
    """Split the provided block at the given indices.

    Args:
        block_id: the id of this block in the block list.
        block: block to be split.
        meta: metadata of the block, we expect meta.num_rows is valid.
        split_indices: the indices where the block should be split. The list is
            not modified.

    Returns:
        returns block_id, split blocks metadata, and a list of blocks
        in the following form. We return blocks in this way
        so that the owner of blocks could be the caller(driver)
        instead of worker itself.
        Tuple(block_id, split_blocks_meta), block0, block1 ...
    """
    split_meta = []
    split_blocks = []
    block_accessor = BlockAccessor.for_block(block)
    prev_index = 0
    # Add the block's row count as a final boundary so the trailing slice is
    # produced by the same loop and the empty-indices case needs no special
    # handling. A new list is built so the caller's list is left untouched.
    boundaries = [*split_indices, meta.num_rows]
    for index in boundaries:
        logger.debug("slicing block %s:%s", prev_index, index)
        # Started before slicing; built into the metadata after the slice is made.
        stats = BlockExecStats.builder()
        split_block = block_accessor.slice(prev_index, index)
        accessor = BlockAccessor.for_block(split_block)
        _meta = BlockMetadata(
            num_rows=accessor.num_rows(),
            size_bytes=accessor.size_bytes(),
            input_files=meta.input_files,
            exec_stats=stats.build(),
        )
        split_meta.append(_meta)
        split_blocks.append(split_block)
        prev_index = index
    # First element carries the id and all metadata; the blocks follow as
    # separate tuple elements.
    return ((block_id, split_meta), *split_blocks)
def _drop_empty_block_split(block_split_indices: List[int], num_rows: int) -> List[int]:
    """Filter out split indices that would produce an empty block split.

    An index is dropped when it equals 0 (start of the block), equals ``num_rows``
    (end of the block), or repeats the most recently kept index.

    Args:
        block_split_indices: block-local split indices.
        num_rows: number of rows in the block.

    Returns:
        The remaining indices, in their original order.
    """
    kept = []
    # -1 acts as the "nothing kept yet" marker.
    last_kept = -1
    for candidate in block_split_indices:
        on_block_edge = candidate == 0 or candidate == num_rows
        if not on_block_edge and candidate != last_kept:
            kept.append(candidate)
            last_kept = candidate
    return kept
def _split_all_blocks(
    blocks_with_metadata: List[Tuple[ObjectRef[Block], BlockMetadata]],
    per_block_split_indices: List[List[int]],
    owned_by_consumer: bool,
) -> Iterable[Tuple[ObjectRef[Block], BlockMetadata]]:
    """Split every input block according to its per-block split indices.

    Args:
        blocks_with_metadata: ``(block ref, metadata)`` pairs to split.
        per_block_split_indices: for each block, the block-local split indices.
        owned_by_consumer: whether the provided blocks are owned by the consumer.

    Returns:
        An iterator over the resulting ``(block ref, metadata)`` pairs, ordered
        by input block and, within a block, by split position.
    """
    remote_split = cached_remote_fn(_split_single_block)
    results_by_block: List[BlockPartition] = [None] * len(blocks_with_metadata)
    # One (metadata future, split block refs) entry per remotely split block.
    submitted = []
    # Input blocks that were actually split; tracked for gc.
    split_inputs = []

    for block_id, raw_indices in enumerate(per_block_split_indices):
        block_ref, meta = blocks_with_metadata[block_id]
        effective_indices = _drop_empty_block_split(raw_indices, meta.num_rows)
        if not effective_indices:
            # Optimization: nothing to cut, so the input block is the result.
            results_by_block[block_id] = [(block_ref, meta)]
            continue
        # The task returns one metadata entry plus len(indices) + 1 blocks.
        task = remote_split.options(
            scheduling_strategy="SPREAD",
            num_returns=2 + len(effective_indices),
        )
        returned_refs = task.remote(block_id, block_ref, meta, effective_indices)
        submitted.append((returned_refs[0], returned_refs[1:]))
        split_inputs.append(block_ref)

    if submitted:
        # Only the metadata is fetched; the split blocks stay as refs.
        resolved = ray.get([meta_future for meta_future, _ in submitted])
        for (block_id, split_metas), (_, block_refs) in zip(resolved, submitted):
            assert len(split_metas) == len(block_refs)
            results_by_block[block_id] = zip(block_refs, split_metas)

    # The split produced copies of these inputs, so the inputs can be cleared
    # if they are owned by consumer (consumer-owned blocks will only be
    # consumed by the owner).
    for input_ref in split_inputs:
        if owned_by_consumer:
            trace_deallocation(input_ref, "split._split_all_blocks")
        else:
            trace_deallocation(input_ref, "split._split_all_blocks", free=False)

    return itertools.chain.from_iterable(results_by_block)
def _generate_global_split_results(
    all_blocks_split_results: Iterable[Tuple[ObjectRef[Block], BlockMetadata]],
    global_split_sizes: List[int],
) -> Tuple[List[List[ObjectRef[Block]]], List[List[BlockMetadata]]]:
    """Reassemble per block's split result into final split result.

    Args:
        all_blocks_split_results: ``(block ref, metadata)`` pairs in global row
            order. Any iterable is accepted (list, iterator, ...); each
            ``metadata.num_rows`` must be set.
        global_split_sizes: the number of rows each output split must contain.

    Returns:
        ``(result_blocks, result_metas)``: for every entry of
        ``global_split_sizes``, the block refs and the matching metadata that
        make up that split.

    Raises:
        AssertionError: if the blocks cannot be grouped to match a split size
            exactly.
        StopIteration: if the input runs out before all splits are filled.
    """
    # next() requires an iterator. iter() leaves an iterator unchanged and makes
    # plain sequences, which the Iterable annotation permits, work as well.
    remaining = iter(all_blocks_split_results)
    result_blocks = []
    result_metas = []
    for target_size in global_split_sizes:
        current_blocks = []
        current_meta = []
        current_split_size = 0
        # Pull blocks until this split holds the requested number of rows.
        while current_split_size < target_size:
            block_ref, meta = next(remaining)
            current_blocks.append(block_ref)
            current_meta.append(meta)
            current_split_size += meta.num_rows
        # Blocks were pre-cut at the split boundaries, so the sizes must match.
        assert current_split_size == target_size
        result_blocks.append(current_blocks)
        result_metas.append(current_meta)
    return result_blocks, result_metas
def _split_at_indices(
    blocks_with_metadata: List[Tuple[ObjectRef[Block], BlockMetadata]],
    indices: List[int],
    owned_by_consumer: bool = True,
    block_rows: Union[List[int], None] = None,
) -> Tuple[List[List[ObjectRef[Block]]], List[List[BlockMetadata]]]:
    """Split blocks at the provided indices.

    Args:
        blocks_with_metadata: Block futures to split, including the associated metadata.
        indices: The (global) indices at which to split the blocks.
        owned_by_consumer: Whether the provided blocks are owned by the consumer.
        block_rows: The number of rows for each block, in case it has already been
            computed.

    Returns:
        The block split futures and their metadata. If an index split is empty, the
        corresponding block split will be empty. ``len(indices) + 1`` splits are
        returned.
    """
    # We implement the split in 3 phases.
    # phase 1: calculate the per block split indices.
    blocks_with_metadata = list(blocks_with_metadata)
    if not blocks_with_metadata:
        # Build an independent list per split. `[[]] * n` would repeat one
        # shared list, so appending to one split would change all of them.
        num_splits = len(indices) + 1
        return (
            [[] for _ in range(num_splits)],
            [[] for _ in range(num_splits)],
        )
    if block_rows is None:
        block_rows = _calculate_blocks_rows(blocks_with_metadata)
    # Indices past the end of the data are clamped to the total row count.
    valid_indices = _generate_valid_indices(block_rows, indices)
    per_block_split_indices: List[List[int]] = _generate_per_block_split_indices(
        block_rows, valid_indices
    )

    # phase 2: split each block based on the indices from previous step.
    all_blocks_split_results: Iterable[
        Tuple[ObjectRef[Block], BlockMetadata]
    ] = _split_all_blocks(
        blocks_with_metadata, per_block_split_indices, owned_by_consumer
    )

    # phase 3: generate the final split.
    # first calculate the size for each split: the distance between
    # consecutive boundaries, with 0 and the total row count as outer bounds.
    boundaries = [0, *valid_indices, sum(block_rows)]
    split_sizes = [end - start for start, end in zip(boundaries, boundaries[1:])]
    return _generate_global_split_results(all_blocks_split_results, split_sizes)
def _get_num_rows(block: Block) -> int:
    """Return the row count reported by the block's accessor.

    Args:
        block: the block to inspect.

    Returns:
        The value of ``BlockAccessor.for_block(block).num_rows()``.
    """
    accessor = BlockAccessor.for_block(block)
    return accessor.num_rows()
|