aggregate.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. from typing import List, Optional, Union
  2. from ray.data._internal.execution.interfaces import (
  3. AllToAllTransformFn,
  4. RefBundle,
  5. TaskContext,
  6. )
  7. from ray.data._internal.execution.interfaces.transform_fn import (
  8. AllToAllTransformFnResult,
  9. )
  10. from ray.data._internal.planner.exchange.aggregate_task_spec import (
  11. SortAggregateTaskSpec,
  12. )
  13. from ray.data._internal.planner.exchange.pull_based_shuffle_task_scheduler import (
  14. PullBasedShuffleTaskScheduler,
  15. )
  16. from ray.data._internal.planner.exchange.push_based_shuffle_task_scheduler import (
  17. PushBasedShuffleTaskScheduler,
  18. )
  19. from ray.data._internal.planner.exchange.sort_task_spec import SortKey, SortTaskSpec
  20. from ray.data._internal.util import unify_ref_bundles_schema
  21. from ray.data.aggregate import AggregateFn
  22. from ray.data.context import DataContext, ShuffleStrategy
  23. def generate_aggregate_fn(
  24. key: Optional[Union[str, List[str]]],
  25. aggs: List[AggregateFn],
  26. batch_format: str,
  27. data_context: DataContext,
  28. _debug_limit_shuffle_execution_to_num_blocks: Optional[int] = None,
  29. ) -> AllToAllTransformFn:
  30. """Generate function to aggregate blocks by the specified key column or key
  31. function.
  32. """
  33. assert data_context.shuffle_strategy in [
  34. ShuffleStrategy.SORT_SHUFFLE_PULL_BASED,
  35. ShuffleStrategy.SORT_SHUFFLE_PUSH_BASED,
  36. ]
  37. if len(aggs) == 0:
  38. raise ValueError("Aggregate requires at least one aggregation")
  39. def fn(
  40. refs: List[RefBundle],
  41. ctx: TaskContext,
  42. ) -> AllToAllTransformFnResult:
  43. blocks = []
  44. metadata = []
  45. for ref_bundle in refs:
  46. blocks.extend(ref_bundle.block_refs)
  47. metadata.extend(ref_bundle.metadata)
  48. if len(blocks) == 0:
  49. return (blocks, {})
  50. unified_schema = unify_ref_bundles_schema(refs)
  51. for agg_fn in aggs:
  52. agg_fn._validate(unified_schema)
  53. num_mappers = len(blocks)
  54. sort_key = SortKey(key)
  55. if key is None:
  56. num_outputs = 1
  57. boundaries = []
  58. else:
  59. # Use same number of output partitions.
  60. num_outputs = num_mappers
  61. sample_bar = ctx.sub_progress_bar_dict[
  62. SortTaskSpec.SORT_SAMPLE_SUB_PROGRESS_BAR_NAME
  63. ]
  64. # Sample boundaries for aggregate key.
  65. boundaries = SortTaskSpec.sample_boundaries(
  66. blocks, sort_key, num_outputs, sample_bar
  67. )
  68. agg_spec = SortAggregateTaskSpec(
  69. boundaries=boundaries,
  70. key=sort_key,
  71. aggs=aggs,
  72. batch_format=batch_format,
  73. )
  74. if data_context.shuffle_strategy == ShuffleStrategy.SORT_SHUFFLE_PUSH_BASED:
  75. scheduler = PushBasedShuffleTaskScheduler(agg_spec)
  76. elif data_context.shuffle_strategy == ShuffleStrategy.SORT_SHUFFLE_PULL_BASED:
  77. scheduler = PullBasedShuffleTaskScheduler(agg_spec)
  78. else:
  79. raise ValueError(
  80. f"Invalid shuffle strategy '{data_context.shuffle_strategy}'"
  81. )
  82. return scheduler.execute(
  83. refs,
  84. num_outputs,
  85. ctx,
  86. _debug_limit_execution_to_num_blocks=(
  87. _debug_limit_shuffle_execution_to_num_blocks
  88. ),
  89. )
  90. return fn