| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788 |
- from typing import List, Optional, Tuple
- from ray.data._internal.execution.interfaces import (
- AllToAllTransformFn,
- RefBundle,
- TaskContext,
- )
- from ray.data._internal.execution.interfaces.transform_fn import (
- AllToAllTransformFnResult,
- )
- from ray.data._internal.execution.operators.map_transformer import MapTransformer
- from ray.data._internal.planner.exchange.pull_based_shuffle_task_scheduler import (
- PullBasedShuffleTaskScheduler,
- )
- from ray.data._internal.planner.exchange.push_based_shuffle_task_scheduler import (
- PushBasedShuffleTaskScheduler,
- )
- from ray.data._internal.planner.exchange.shuffle_task_spec import ShuffleTaskSpec
- from ray.data._internal.planner.exchange.split_repartition_task_scheduler import (
- SplitRepartitionTaskScheduler,
- )
- from ray.data._internal.stats import StatsDict
- from ray.data.context import DataContext, ShuffleStrategy
def generate_repartition_fn(
    num_outputs: int,
    shuffle: bool,
    data_context: DataContext,
    _debug_limit_shuffle_execution_to_num_blocks: Optional[int] = None,
) -> AllToAllTransformFn:
    """Build the all-to-all transform function used to repartition blocks.

    Two closures are defined and exactly one of them is returned:

    * ``shuffle_repartition_fn`` (returned when ``shuffle`` is truthy) runs a
      non-random shuffle through either the push-based or the pull-based
      shuffle task scheduler, depending on ``data_context.shuffle_strategy``.
    * ``split_repartition_fn`` (returned otherwise) delegates to
      ``SplitRepartitionTaskScheduler``.

    Args:
        num_outputs: Forwarded unchanged to the scheduler's ``execute`` call
            (presumably the desired number of output blocks -- confirm
            against the scheduler implementations).
        shuffle: Selects which of the two closures is returned.
        data_context: Read lazily, each time the returned function runs, for
            ``target_max_block_size`` and (shuffle path only)
            ``shuffle_strategy``.
        _debug_limit_shuffle_execution_to_num_blocks: Forwarded as
            ``_debug_limit_execution_to_num_blocks`` to the shuffle
            scheduler; not used by the split path.

    Returns:
        A function taking ``(refs, ctx)`` and returning whatever the selected
        scheduler's ``execute`` returns.
    """

    def _target_block_size(ctx: TaskContext):
        # A truthy per-task override takes precedence over the context-wide
        # setting.
        return ctx.target_max_block_size_override or data_context.target_max_block_size

    def shuffle_repartition_fn(
        refs: List[RefBundle],
        ctx: TaskContext,
    ) -> Tuple[List[RefBundle], StatsDict]:
        # An upstream map transformer may be attached to the task context
        # (e.g. when a MapOperator was fused into this AllToAllOperator). If
        # so, it is wrapped in a function handed to the shuffle spec so it is
        # applied to the blocks before they are shuffled.
        fused_transformer: Optional["MapTransformer"] = ctx.upstream_map_transformer
        if not fused_transformer:
            upstream_map_fn = None
        else:
            # NOTE: The fused transformation's own target max-block size is
            #       cleared so it performs no unnecessary block shaping.
            fused_transformer.override_target_max_block_size(None)

            def upstream_map_fn(blocks):
                return fused_transformer.apply_transform(blocks, ctx)

        spec = ShuffleTaskSpec(
            target_shuffle_max_block_size=_target_block_size(ctx),
            random_shuffle=False,
            upstream_map_fn=upstream_map_fn,
        )

        # Push-based scheduling only for the matching strategy; every other
        # strategy falls back to the pull-based scheduler.
        scheduler_cls = (
            PushBasedShuffleTaskScheduler
            if data_context.shuffle_strategy == ShuffleStrategy.SORT_SHUFFLE_PUSH_BASED
            else PullBasedShuffleTaskScheduler
        )
        return scheduler_cls(spec).execute(
            refs,
            num_outputs,
            ctx,
            _debug_limit_execution_to_num_blocks=(
                _debug_limit_shuffle_execution_to_num_blocks
            ),
        )

    def split_repartition_fn(
        refs: List[RefBundle],
        ctx: TaskContext,
    ) -> AllToAllTransformFnResult:
        # No upstream map function is passed on this path.
        spec = ShuffleTaskSpec(
            target_shuffle_max_block_size=_target_block_size(ctx),
            random_shuffle=False,
        )
        return SplitRepartitionTaskScheduler(spec).execute(refs, num_outputs, ctx)

    return shuffle_repartition_fn if shuffle else split_repartition_fn
|