repartition.py 3.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. from typing import List, Optional, Tuple
  2. from ray.data._internal.execution.interfaces import (
  3. AllToAllTransformFn,
  4. RefBundle,
  5. TaskContext,
  6. )
  7. from ray.data._internal.execution.interfaces.transform_fn import (
  8. AllToAllTransformFnResult,
  9. )
  10. from ray.data._internal.execution.operators.map_transformer import MapTransformer
  11. from ray.data._internal.planner.exchange.pull_based_shuffle_task_scheduler import (
  12. PullBasedShuffleTaskScheduler,
  13. )
  14. from ray.data._internal.planner.exchange.push_based_shuffle_task_scheduler import (
  15. PushBasedShuffleTaskScheduler,
  16. )
  17. from ray.data._internal.planner.exchange.shuffle_task_spec import ShuffleTaskSpec
  18. from ray.data._internal.planner.exchange.split_repartition_task_scheduler import (
  19. SplitRepartitionTaskScheduler,
  20. )
  21. from ray.data._internal.stats import StatsDict
  22. from ray.data.context import DataContext, ShuffleStrategy
  23. def generate_repartition_fn(
  24. num_outputs: int,
  25. shuffle: bool,
  26. data_context: DataContext,
  27. _debug_limit_shuffle_execution_to_num_blocks: Optional[int] = None,
  28. ) -> AllToAllTransformFn:
  29. """Generate function to partition each records of blocks."""
  30. def shuffle_repartition_fn(
  31. refs: List[RefBundle],
  32. ctx: TaskContext,
  33. ) -> Tuple[List[RefBundle], StatsDict]:
  34. # If map_transformer is specified (e.g. from fusing
  35. # MapOperator->AllToAllOperator), we pass a map function which
  36. # is applied to each block before shuffling.
  37. map_transformer: Optional["MapTransformer"] = ctx.upstream_map_transformer
  38. upstream_map_fn = None
  39. if map_transformer:
  40. # NOTE: We override target max-block sizing of the previous
  41. # transformation to avoid unnecessary block shaping (if any)
  42. map_transformer.override_target_max_block_size(None)
  43. def upstream_map_fn(blocks):
  44. return map_transformer.apply_transform(blocks, ctx)
  45. shuffle_spec = ShuffleTaskSpec(
  46. target_shuffle_max_block_size=(
  47. ctx.target_max_block_size_override or data_context.target_max_block_size
  48. ),
  49. random_shuffle=False,
  50. upstream_map_fn=upstream_map_fn,
  51. )
  52. if data_context.shuffle_strategy == ShuffleStrategy.SORT_SHUFFLE_PUSH_BASED:
  53. scheduler = PushBasedShuffleTaskScheduler(shuffle_spec)
  54. else:
  55. scheduler = PullBasedShuffleTaskScheduler(shuffle_spec)
  56. return scheduler.execute(
  57. refs,
  58. num_outputs,
  59. ctx,
  60. _debug_limit_execution_to_num_blocks=(
  61. _debug_limit_shuffle_execution_to_num_blocks
  62. ),
  63. )
  64. def split_repartition_fn(
  65. refs: List[RefBundle],
  66. ctx: TaskContext,
  67. ) -> AllToAllTransformFnResult:
  68. shuffle_spec = ShuffleTaskSpec(
  69. target_shuffle_max_block_size=(
  70. ctx.target_max_block_size_override or data_context.target_max_block_size
  71. ),
  72. random_shuffle=False,
  73. )
  74. scheduler = SplitRepartitionTaskScheduler(shuffle_spec)
  75. return scheduler.execute(refs, num_outputs, ctx)
  76. if shuffle:
  77. return shuffle_repartition_fn
  78. return split_repartition_fn