pass_manager.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253
  1. # mypy: allow-untyped-defs
  2. import logging
  3. from collections.abc import Callable
  4. from functools import wraps
  5. from inspect import unwrap
  6. from typing import Optional
  7. logger = logging.getLogger(__name__)
  8. __all__ = [
  9. "PassManager",
  10. "inplace_wrapper",
  11. "log_hook",
  12. "loop_pass",
  13. "this_before_that_pass_constraint",
  14. "these_before_those_pass_constraint",
  15. ]
  16. # for callables which modify object inplace and return something other than
  17. # the object on which they act
  18. def inplace_wrapper(fn: Callable) -> Callable:
  19. """
  20. Convenience wrapper for passes which modify an object inplace. This
  21. wrapper makes them return the modified object instead.
  22. Args:
  23. fn (Callable[Object, Any])
  24. Returns:
  25. wrapped_fn (Callable[Object, Object])
  26. """
  27. @wraps(fn)
  28. def wrapped_fn(gm):
  29. fn(gm)
  30. return gm
  31. return wrapped_fn
  32. def log_hook(fn: Callable, level=logging.INFO) -> Callable:
  33. """
  34. Logs callable output.
  35. This is useful for logging output of passes. Note inplace_wrapper replaces
  36. the pass output with the modified object. If we want to log the original
  37. output, apply this wrapper before inplace_wrapper.
  38. ```
  39. def my_pass(d: Dict) -> bool:
  40. changed = False
  41. if "foo" in d:
  42. d["foo"] = "bar"
  43. changed = True
  44. return changed
  45. pm = PassManager(passes=[inplace_wrapper(log_hook(my_pass))])
  46. ```
  47. Args:
  48. fn (Callable[Type1, Type2])
  49. level: logging level (e.g. logging.INFO)
  50. Returns:
  51. wrapped_fn (Callable[Type1, Type2])
  52. """
  53. @wraps(fn)
  54. def wrapped_fn(gm):
  55. val = fn(gm)
  56. logger.log(level, "Ran pass %s\t Return value: %s", fn, val)
  57. return val
  58. return wrapped_fn
  59. def loop_pass(
  60. base_pass: Callable,
  61. n_iter: Optional[int] = None,
  62. predicate: Optional[Callable] = None,
  63. ):
  64. """
  65. Convenience wrapper for passes which need to be applied multiple times.
  66. Exactly one of `n_iter`or `predicate` must be specified.
  67. Args:
  68. base_pass (Callable[Object, Object]): pass to be applied in loop
  69. n_iter (int, optional): number of times to loop pass
  70. predicate (Callable[Object, bool], optional):
  71. """
  72. if not ((n_iter is not None) ^ (predicate is not None)):
  73. raise AssertionError("Exactly one of `n_iter`or `predicate` must be specified.")
  74. @wraps(base_pass)
  75. def new_pass(source):
  76. output = source
  77. if n_iter is not None and n_iter > 0:
  78. for _ in range(n_iter):
  79. output = base_pass(output)
  80. elif predicate is not None:
  81. while predicate(output):
  82. output = base_pass(output)
  83. else:
  84. raise RuntimeError(
  85. f"loop_pass must be given positive int n_iter (given "
  86. f"{n_iter}) xor predicate (given {predicate})"
  87. )
  88. return output
  89. return new_pass
  90. # Pass Schedule Constraints:
  91. #
  92. # Implemented as 'depends on' operators. A constraint is satisfied iff a list
  93. # has a valid partial ordering according to this comparison operator.
  94. def _validate_pass_schedule_constraint(
  95. constraint: Callable[[Callable, Callable], bool], passes: list[Callable]
  96. ):
  97. for i, a in enumerate(passes):
  98. for j, b in enumerate(passes[i + 1 :]):
  99. if constraint(a, b):
  100. continue
  101. raise RuntimeError(
  102. f"pass schedule constraint violated. Expected {a} before {b}"
  103. f" but found {a} at index {i} and {b} at index{j} in pass"
  104. f" list."
  105. )
  106. def this_before_that_pass_constraint(this: Callable, that: Callable):
  107. """
  108. Defines a partial order ('depends on' function) where `this` must occur
  109. before `that`.
  110. """
  111. def depends_on(a: Callable, b: Callable):
  112. return a != that or b != this
  113. return depends_on
  114. def these_before_those_pass_constraint(these: Callable, those: Callable):
  115. """
  116. Defines a partial order ('depends on' function) where `these` must occur
  117. before `those`. Where the inputs are 'unwrapped' before comparison.
  118. For example, the following pass list and constraint list would be invalid.
  119. ```
  120. passes = [
  121. loop_pass(pass_b, 3),
  122. loop_pass(pass_a, 5),
  123. ]
  124. constraints = [these_before_those_pass_constraint(pass_a, pass_b)]
  125. ```
  126. Args:
  127. these (Callable): pass which should occur first
  128. those (Callable): pass which should occur later
  129. Returns:
  130. depends_on (Callable[[Object, Object], bool]
  131. """
  132. def depends_on(a: Callable, b: Callable):
  133. return unwrap(a) != those or unwrap(b) != these
  134. return depends_on
  135. class PassManager:
  136. """
  137. Construct a PassManager.
  138. Collects passes and constraints. This defines the pass schedule, manages
  139. pass constraints and pass execution.
  140. Args:
  141. passes (Optional[List[Callable]]): list of passes. A pass is a
  142. callable which modifies an object and returns modified object
  143. constraint (Optional[List[Callable]]): list of constraints. A
  144. constraint is a callable which takes two passes (A, B) and returns
  145. True if A depends on B and False otherwise. See implementation of
  146. `this_before_that_pass_constraint` for example.
  147. """
  148. passes: list[Callable]
  149. constraints: list[Callable]
  150. _validated: bool = False
  151. def __init__(
  152. self,
  153. passes=None,
  154. constraints=None,
  155. ):
  156. self.passes = passes or []
  157. self.constraints = constraints or []
  158. @classmethod
  159. def build_from_passlist(cls, passes):
  160. pm = PassManager(passes)
  161. # TODO(alexbeloi): add constraint management/validation
  162. return pm
  163. def add_pass(self, _pass: Callable):
  164. self.passes.append(_pass)
  165. self._validated = False
  166. def add_constraint(self, constraint):
  167. self.constraints.append(constraint)
  168. self._validated = False
  169. def remove_pass(self, _passes: list[str]):
  170. if _passes is None:
  171. return
  172. passes_left = [ps for ps in self.passes if ps.__name__ not in _passes]
  173. self.passes = passes_left
  174. self._validated = False
  175. def replace_pass(self, _target, _replacement):
  176. passes_left = []
  177. for ps in self.passes:
  178. if ps.__name__ == _target.__name__:
  179. passes_left.append(_replacement)
  180. else:
  181. passes_left.append(ps)
  182. self.passes = passes_left
  183. self._validated = False
  184. def validate(self):
  185. """
  186. Validates that current pass schedule defined by `self.passes` is valid
  187. according to all constraints in `self.constraints`
  188. """
  189. if self._validated:
  190. return
  191. for constraint in self.constraints:
  192. _validate_pass_schedule_constraint(constraint, self.passes)
  193. self._validated = True
  194. def __call__(self, source):
  195. self.validate()
  196. out = source
  197. for _pass in self.passes:
  198. out = _pass(out)
  199. return out