planner.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329
  1. import functools
  2. import warnings
  3. from typing import Callable, Dict, List, Optional, Tuple, Type, TypeVar
  4. from ray import ObjectRef
  5. from ray.data._internal.execution.execution_callback import add_execution_callback
  6. from ray.data._internal.execution.interfaces import PhysicalOperator
  7. from ray.data._internal.execution.operators.aggregate_num_rows import (
  8. AggregateNumRows,
  9. )
  10. from ray.data._internal.execution.operators.input_data_buffer import (
  11. InputDataBuffer,
  12. )
  13. from ray.data._internal.execution.operators.join import JoinOperator
  14. from ray.data._internal.execution.operators.limit_operator import LimitOperator
  15. from ray.data._internal.execution.operators.output_splitter import OutputSplitter
  16. from ray.data._internal.execution.operators.union_operator import UnionOperator
  17. from ray.data._internal.execution.operators.zip_operator import ZipOperator
  18. from ray.data._internal.logical.interfaces import (
  19. LogicalOperator,
  20. LogicalPlan,
  21. PhysicalPlan,
  22. )
  23. from ray.data._internal.logical.operators import (
  24. AbstractAllToAll,
  25. AbstractFrom,
  26. AbstractUDFMap,
  27. Count,
  28. Download,
  29. Filter,
  30. InputData,
  31. Join,
  32. Limit,
  33. Project,
  34. Read,
  35. StreamingRepartition,
  36. StreamingSplit,
  37. Union,
  38. Write,
  39. Zip,
  40. )
  41. from ray.data._internal.planner.checkpoint import (
  42. plan_read_op_with_checkpoint_filter,
  43. plan_write_op_with_checkpoint_writer,
  44. )
  45. from ray.data._internal.planner.plan_all_to_all_op import plan_all_to_all_op
  46. from ray.data._internal.planner.plan_download_op import plan_download_op
  47. from ray.data._internal.planner.plan_read_op import plan_read_op
  48. from ray.data._internal.planner.plan_udf_map_op import (
  49. plan_filter_op,
  50. plan_project_op,
  51. plan_streaming_repartition_op,
  52. plan_udf_map_op,
  53. )
  54. from ray.data._internal.planner.plan_write_op import plan_write_op
  55. from ray.data.checkpoint.load_checkpoint_callback import LoadCheckpointCallback
  56. from ray.data.context import DataContext
  57. LogicalOperatorType = TypeVar("LogicalOperatorType", bound=LogicalOperator)
  58. PlanLogicalOpFn = Callable[
  59. [LogicalOperatorType, List[PhysicalOperator], DataContext], PhysicalOperator
  60. ]
  61. def plan_input_data_op(
  62. logical_op: InputData,
  63. physical_children: List[PhysicalOperator],
  64. data_context: DataContext,
  65. ) -> PhysicalOperator:
  66. """Get the corresponding DAG of physical operators for InputData."""
  67. assert len(physical_children) == 0
  68. return InputDataBuffer(
  69. data_context,
  70. input_data=logical_op.input_data,
  71. )
  72. def plan_from_op(
  73. op: AbstractFrom,
  74. physical_children: List[PhysicalOperator],
  75. data_context: DataContext,
  76. ) -> PhysicalOperator:
  77. assert len(physical_children) == 0
  78. return InputDataBuffer(data_context, op.input_data)
  79. def plan_zip_op(_, physical_children, data_context):
  80. assert len(physical_children) >= 2
  81. return ZipOperator(data_context, *physical_children)
  82. def plan_union_op(_, physical_children, data_context):
  83. assert len(physical_children) >= 2
  84. return UnionOperator(data_context, *physical_children)
  85. def plan_limit_op(logical_op, physical_children, data_context):
  86. assert len(physical_children) == 1
  87. return LimitOperator(logical_op.limit, physical_children[0], data_context)
  88. def plan_count_op(logical_op, physical_children, data_context):
  89. assert len(physical_children) == 1
  90. return AggregateNumRows(
  91. [physical_children[0]], data_context, column_name=Count.COLUMN_NAME
  92. )
  93. def plan_join_op(
  94. logical_op: Join,
  95. physical_children: List[PhysicalOperator],
  96. data_context: DataContext,
  97. ) -> PhysicalOperator:
  98. assert len(physical_children) == 2
  99. return JoinOperator(
  100. data_context=data_context,
  101. left_input_op=physical_children[0],
  102. right_input_op=physical_children[1],
  103. join_type=logical_op.join_type,
  104. left_key_columns=logical_op.left_key_columns,
  105. right_key_columns=logical_op.right_key_columns,
  106. left_columns_suffix=logical_op.left_columns_suffix,
  107. right_columns_suffix=logical_op.right_columns_suffix,
  108. num_partitions=logical_op.num_outputs,
  109. partition_size_hint=logical_op.partition_size_hint,
  110. aggregator_ray_remote_args_override=logical_op.aggregator_ray_remote_args,
  111. )
  112. def plan_streaming_split_op(
  113. logical_op: StreamingSplit,
  114. physical_children: List[PhysicalOperator],
  115. data_context: DataContext,
  116. ):
  117. assert len(physical_children) == 1
  118. return OutputSplitter(
  119. physical_children[0],
  120. n=logical_op.num_splits,
  121. equal=logical_op.equal,
  122. data_context=data_context,
  123. locality_hints=logical_op.locality_hints,
  124. )
  125. class Planner:
  126. """The planner to convert optimized logical to physical operators.
  127. Note that planner is only doing operators conversion. Physical optimization work is
  128. done by physical optimizer.
  129. """
  130. _DEFAULT_PLAN_FNS = {
  131. Read: plan_read_op,
  132. InputData: plan_input_data_op,
  133. Write: plan_write_op,
  134. AbstractFrom: plan_from_op,
  135. Filter: plan_filter_op,
  136. AbstractUDFMap: plan_udf_map_op,
  137. AbstractAllToAll: plan_all_to_all_op,
  138. Union: plan_union_op,
  139. Zip: plan_zip_op,
  140. Limit: plan_limit_op,
  141. Count: plan_count_op,
  142. Project: plan_project_op,
  143. StreamingRepartition: plan_streaming_repartition_op,
  144. Join: plan_join_op,
  145. StreamingSplit: plan_streaming_split_op,
  146. Download: plan_download_op,
  147. }
  148. # Operators that support checkpoint filtering. Subclasses can override.
  149. _CHECKPOINT_FILTER_OPS = (Read,)
  150. def __init__(self):
  151. self._supports_checkpointing = False
  152. self._plan_fns_for_checkpointing = {}
  153. def plan(self, logical_plan: LogicalPlan) -> PhysicalPlan:
  154. """Convert logical to physical operators recursively in post-order."""
  155. checkpoint_config = logical_plan.context.checkpoint_config
  156. if checkpoint_config is not None and self._check_supports_checkpointing(
  157. logical_plan
  158. ):
  159. self._supports_checkpointing = True
  160. checkpoint_callback = self._create_checkpoint_callback(checkpoint_config)
  161. add_execution_callback(checkpoint_callback, logical_plan.context)
  162. load_checkpoint = checkpoint_callback.load_checkpoint
  163. # Dynamically set the plan functions for checkpointing because they
  164. # need to a reference to the checkpoint ref.
  165. self._plan_fns_for_checkpointing = self._get_plan_fns_for_checkpointing(
  166. load_checkpoint
  167. )
  168. elif checkpoint_config is not None:
  169. assert not self._check_supports_checkpointing(logical_plan)
  170. warnings.warn(
  171. "You've enabled checkpointing, but the logical plan doesn't support "
  172. "checkpointing. Checkpointing will be disabled."
  173. )
  174. physical_dag, op_map = self._plan_recursively(
  175. logical_plan.dag, logical_plan.context
  176. )
  177. physical_plan = PhysicalPlan(physical_dag, op_map, logical_plan.context)
  178. return physical_plan
  179. def get_plan_fn(self, logical_op: LogicalOperator) -> PlanLogicalOpFn:
  180. if self._supports_checkpointing:
  181. assert self._plan_fns_for_checkpointing
  182. plan_fn = find_plan_fn(logical_op, self._plan_fns_for_checkpointing)
  183. if plan_fn is not None:
  184. return plan_fn
  185. plan_fn = find_plan_fn(logical_op, self._DEFAULT_PLAN_FNS)
  186. if plan_fn is not None:
  187. return plan_fn
  188. raise ValueError(
  189. f"Found unknown logical operator during planning: {logical_op}"
  190. )
  191. def _plan_recursively(
  192. self, logical_op: LogicalOperator, data_context: DataContext
  193. ) -> Tuple[PhysicalOperator, Dict[LogicalOperator, PhysicalOperator]]:
  194. """Plan a logical operator and its input dependencies recursively.
  195. Args:
  196. logical_op: The logical operator to plan.
  197. data_context: The data context.
  198. Returns:
  199. A tuple of the physical operator corresponding to the logical operator, and
  200. a mapping from physical to logical operators.
  201. """
  202. op_map: Dict[PhysicalOperator, LogicalOperator] = {}
  203. # Plan the input dependencies first.
  204. physical_children = []
  205. for child in logical_op.input_dependencies:
  206. physical_child, child_op_map = self._plan_recursively(child, data_context)
  207. physical_children.append(physical_child)
  208. op_map.update(child_op_map)
  209. plan_fn = self.get_plan_fn(logical_op)
  210. # We will call `set_logical_operators()` in the following for-loop,
  211. # no need to do it here.
  212. physical_op = plan_fn(logical_op, physical_children, data_context)
  213. # Traverse up the DAG, and set the mapping from physical to logical operators.
  214. # At this point, all physical operators without logical operators set
  215. # must have been created by the current logical operator.
  216. queue = [physical_op]
  217. while queue:
  218. curr_physical_op = queue.pop()
  219. # Once we find an operator with a logical operator set, we can stop.
  220. if curr_physical_op._logical_operators:
  221. break
  222. curr_physical_op.set_logical_operators(logical_op)
  223. # Add this operator to the op_map so optimizer can find it
  224. op_map[curr_physical_op] = logical_op
  225. queue.extend(curr_physical_op.input_dependencies)
  226. # Also add the final operator (in case the loop didn't catch it)
  227. op_map[physical_op] = logical_op
  228. return physical_op, op_map
  229. def _create_checkpoint_callback(self, checkpoint_config) -> LoadCheckpointCallback:
  230. """Factory method to create the LoadCheckpointCallback.
  231. Subclasses can override this to use a different callback implementation.
  232. """
  233. return LoadCheckpointCallback(checkpoint_config)
  234. def _get_plan_fns_for_checkpointing(
  235. self,
  236. load_checkpoint: Callable[[], ObjectRef],
  237. ) -> Dict[Type[LogicalOperator], PlanLogicalOpFn]:
  238. plan_fns = {
  239. Read: functools.partial(
  240. plan_read_op_with_checkpoint_filter,
  241. load_checkpoint=load_checkpoint,
  242. ),
  243. Write: plan_write_op_with_checkpoint_writer,
  244. }
  245. return plan_fns
  246. def _check_supports_checkpointing(self, logical_plan: LogicalPlan) -> bool:
  247. """Check if the logical plan supports checkpointing.
  248. Subclasses can override _CHECKPOINT_FILTER_OPS to support more operators.
  249. """
  250. if not isinstance(logical_plan.dag, (Write, StreamingSplit)):
  251. return False
  252. def _all_paths_contain_checkpoint_filter(op: LogicalOperator) -> bool:
  253. if isinstance(op, self._CHECKPOINT_FILTER_OPS):
  254. return True
  255. return all(
  256. _all_paths_contain_checkpoint_filter(input_dep)
  257. for input_dep in op.input_dependencies
  258. )
  259. return _all_paths_contain_checkpoint_filter(logical_plan.dag)
  260. def find_plan_fn(
  261. logical_op: LogicalOperator, plan_fns: Dict[Type[LogicalOperator], PlanLogicalOpFn]
  262. ) -> Optional[PlanLogicalOpFn]:
  263. """Find the plan function for a logical operator.
  264. This function goes through the plan functions in order and returns the first one
  265. that is an instance of the logical operator type.
  266. Args:
  267. logical_op: The logical operator to find the plan function for.
  268. plan_fns: The dictionary of plan functions.
  269. Returns:
  270. The plan function for the logical operator, or None if no plan function is
  271. found.
  272. """
  273. # TODO: This implementation doesn't account for type hierarchies conflicts or
  274. # multiple inheritance.
  275. for op_type, plan_fn in plan_fns.items():
  276. if isinstance(logical_op, op_type):
  277. return plan_fn
  278. return None