ruleset.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. import collections
  2. from dataclasses import dataclass, field
  3. from typing import Dict, Iterator, List, Optional, Type
  4. from ray.data._internal.logical.interfaces import Rule
  5. from ray.util.annotations import DeveloperAPI
  6. @DeveloperAPI
  7. class Ruleset:
  8. """A collection of rules to apply to a plan.
  9. This is a utility class to ensure that, if rules depend on each other, they're
  10. applied in a correct order.
  11. """
  12. @dataclass(frozen=True)
  13. class _Node:
  14. rule: Type[Rule]
  15. dependents: List["Ruleset._Node"] = field(default_factory=list)
  16. def __init__(self, rules: Optional[List[Type[Rule]]] = None):
  17. if rules is None:
  18. rules = []
  19. self._rules = list(rules)
  20. def add(self, rule: Type[Rule]):
  21. if rule in self._rules:
  22. raise ValueError(f"Rule {rule} already in ruleset")
  23. self._rules.append(rule)
  24. if self._contains_cycle():
  25. raise ValueError("Cannot add rule that would create a cycle")
  26. def remove(self, rule: Type[Rule]):
  27. if rule not in self._rules:
  28. raise ValueError(f"Rule {rule} not found in ruleset")
  29. self._rules.remove(rule)
  30. def __iter__(self) -> Iterator[Type[Rule]]:
  31. """Iterate over the rules in this ruleset.
  32. This method yields rules in dependency order. For example, if B depends on A,
  33. then this method yields A before B. The order is otherwise undefined.
  34. """
  35. roots = self._build_graph()
  36. queue = collections.deque(roots)
  37. while queue:
  38. node = queue.popleft()
  39. yield node.rule
  40. queue.extend(node.dependents)
  41. def _build_graph(self) -> List["Ruleset._Node"]:
  42. # NOTE: Because the number of rules will always be relatively small, I've opted
  43. # for a simpler but inefficient implementation.
  44. # Step 1: Add edges from dependencies to their dependants.
  45. rule_to_node: Dict[Type[Rule], "Ruleset._Node"] = {
  46. rule: Ruleset._Node(rule) for rule in self._rules
  47. }
  48. for rule in self._rules:
  49. node = rule_to_node[rule]
  50. # These are rules that must be applied *before* this rule.
  51. for dependency in rule.dependencies():
  52. if dependency in rule_to_node:
  53. rule_to_node[dependency].dependents.append(node)
  54. # These are rules that must be applied *after* this rule.
  55. for dependent in rule.dependents():
  56. if dependent in rule_to_node:
  57. node.dependents.append(rule_to_node[dependent])
  58. # Step 2: Determine which nodes are roots.
  59. roots = list(rule_to_node.values())
  60. for node in rule_to_node.values():
  61. for dependent in node.dependents:
  62. if dependent in roots:
  63. roots.remove(dependent)
  64. return roots
  65. def _contains_cycle(self) -> bool:
  66. if not self._rules:
  67. return
  68. # If the graph contains nodes but there aren't any root nodes, it means that
  69. # there must be a cycle.
  70. roots = self._build_graph()
  71. return not roots