from typing import List, Optional, Tuple from ray.data._internal.execution.interfaces import ( AllToAllTransformFn, RefBundle, TaskContext, ) from ray.data._internal.execution.interfaces.transform_fn import ( AllToAllTransformFnResult, ) from ray.data._internal.execution.operators.map_transformer import MapTransformer from ray.data._internal.planner.exchange.pull_based_shuffle_task_scheduler import ( PullBasedShuffleTaskScheduler, ) from ray.data._internal.planner.exchange.push_based_shuffle_task_scheduler import ( PushBasedShuffleTaskScheduler, ) from ray.data._internal.planner.exchange.shuffle_task_spec import ShuffleTaskSpec from ray.data._internal.planner.exchange.split_repartition_task_scheduler import ( SplitRepartitionTaskScheduler, ) from ray.data._internal.stats import StatsDict from ray.data.context import DataContext, ShuffleStrategy def generate_repartition_fn( num_outputs: int, shuffle: bool, data_context: DataContext, _debug_limit_shuffle_execution_to_num_blocks: Optional[int] = None, ) -> AllToAllTransformFn: """Generate function to partition each records of blocks.""" def shuffle_repartition_fn( refs: List[RefBundle], ctx: TaskContext, ) -> Tuple[List[RefBundle], StatsDict]: # If map_transformer is specified (e.g. from fusing # MapOperator->AllToAllOperator), we pass a map function which # is applied to each block before shuffling. map_transformer: Optional["MapTransformer"] = ctx.upstream_map_transformer upstream_map_fn = None if map_transformer: # NOTE: We override target max-block sizing of the previous # transformation to avoid unnecessary block shaping (if any) map_transformer.override_target_max_block_size(None) def upstream_map_fn(blocks): return map_transformer.apply_transform(blocks, ctx) shuffle_spec = ShuffleTaskSpec( target_shuffle_max_block_size=( ctx.target_max_block_size_override or data_context.target_max_block_size ), random_shuffle=False, upstream_map_fn=upstream_map_fn, ) if data_context.shuffle_strategy == ShuffleStrategy.SORT_SHUFFLE_PUSH_BASED: scheduler = PushBasedShuffleTaskScheduler(shuffle_spec) else: scheduler = PullBasedShuffleTaskScheduler(shuffle_spec) return scheduler.execute( refs, num_outputs, ctx, _debug_limit_execution_to_num_blocks=( _debug_limit_shuffle_execution_to_num_blocks ), ) def split_repartition_fn( refs: List[RefBundle], ctx: TaskContext, ) -> AllToAllTransformFnResult: shuffle_spec = ShuffleTaskSpec( target_shuffle_max_block_size=( ctx.target_max_block_size_override or data_context.target_max_block_size ), random_shuffle=False, ) scheduler = SplitRepartitionTaskScheduler(shuffle_spec) return scheduler.execute(refs, num_outputs, ctx) if shuffle: return shuffle_repartition_fn return split_repartition_fn