| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103 |
- from typing import List, Optional, Union
- from ray.data._internal.execution.interfaces import (
- AllToAllTransformFn,
- RefBundle,
- TaskContext,
- )
- from ray.data._internal.execution.interfaces.transform_fn import (
- AllToAllTransformFnResult,
- )
- from ray.data._internal.planner.exchange.aggregate_task_spec import (
- SortAggregateTaskSpec,
- )
- from ray.data._internal.planner.exchange.pull_based_shuffle_task_scheduler import (
- PullBasedShuffleTaskScheduler,
- )
- from ray.data._internal.planner.exchange.push_based_shuffle_task_scheduler import (
- PushBasedShuffleTaskScheduler,
- )
- from ray.data._internal.planner.exchange.sort_task_spec import SortKey, SortTaskSpec
- from ray.data._internal.util import unify_ref_bundles_schema
- from ray.data.aggregate import AggregateFn
- from ray.data.context import DataContext, ShuffleStrategy
- def generate_aggregate_fn(
- key: Optional[Union[str, List[str]]],
- aggs: List[AggregateFn],
- batch_format: str,
- data_context: DataContext,
- _debug_limit_shuffle_execution_to_num_blocks: Optional[int] = None,
- ) -> AllToAllTransformFn:
- """Generate function to aggregate blocks by the specified key column or key
- function.
- """
- assert data_context.shuffle_strategy in [
- ShuffleStrategy.SORT_SHUFFLE_PULL_BASED,
- ShuffleStrategy.SORT_SHUFFLE_PUSH_BASED,
- ]
- if len(aggs) == 0:
- raise ValueError("Aggregate requires at least one aggregation")
- def fn(
- refs: List[RefBundle],
- ctx: TaskContext,
- ) -> AllToAllTransformFnResult:
- blocks = []
- metadata = []
- for ref_bundle in refs:
- blocks.extend(ref_bundle.block_refs)
- metadata.extend(ref_bundle.metadata)
- if len(blocks) == 0:
- return (blocks, {})
- unified_schema = unify_ref_bundles_schema(refs)
- for agg_fn in aggs:
- agg_fn._validate(unified_schema)
- num_mappers = len(blocks)
- sort_key = SortKey(key)
- if key is None:
- num_outputs = 1
- boundaries = []
- else:
- # Use same number of output partitions.
- num_outputs = num_mappers
- sample_bar = ctx.sub_progress_bar_dict[
- SortTaskSpec.SORT_SAMPLE_SUB_PROGRESS_BAR_NAME
- ]
- # Sample boundaries for aggregate key.
- boundaries = SortTaskSpec.sample_boundaries(
- blocks, sort_key, num_outputs, sample_bar
- )
- agg_spec = SortAggregateTaskSpec(
- boundaries=boundaries,
- key=sort_key,
- aggs=aggs,
- batch_format=batch_format,
- )
- if data_context.shuffle_strategy == ShuffleStrategy.SORT_SHUFFLE_PUSH_BASED:
- scheduler = PushBasedShuffleTaskScheduler(agg_spec)
- elif data_context.shuffle_strategy == ShuffleStrategy.SORT_SHUFFLE_PULL_BASED:
- scheduler = PullBasedShuffleTaskScheduler(agg_spec)
- else:
- raise ValueError(
- f"Invalid shuffle strategy '{data_context.shuffle_strategy}'"
- )
- return scheduler.execute(
- refs,
- num_outputs,
- ctx,
- _debug_limit_execution_to_num_blocks=(
- _debug_limit_shuffle_execution_to_num_blocks
- ),
- )
- return fn
|