| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253 |
- # mypy: allow-untyped-defs
- import logging
- from collections.abc import Callable
- from functools import wraps
- from inspect import unwrap
- from typing import Optional
# Public surface of this module; governs `from <module> import *`.
__all__ = [
    "PassManager",
    "inplace_wrapper",
    "log_hook",
    "loop_pass",
    "this_before_that_pass_constraint",
    "these_before_those_pass_constraint",
]

# Module logger; `log_hook` reports pass return values through it.
logger = logging.getLogger(__name__)
# Adapter for passes that mutate their argument in place but return
# something other than that argument (for example a "changed" flag).
def inplace_wrapper(fn: Callable) -> Callable:
    """
    Adapt an in-place pass so that it returns the object it was given.

    Whatever ``fn`` returns is discarded; the wrapper hands back its own
    argument instead, so in-place passes can be chained like ordinary ones.

    Args:
        fn (Callable[Object, Any]): pass that mutates its single argument.

    Returns:
        wrapped_fn (Callable[Object, Object]): calls ``fn`` on its argument
        and returns that same argument. ``functools.wraps`` copies ``fn``'s
        metadata (``__name__``, ``__wrapped__``, ...) onto it.
    """

    @wraps(fn)
    def _return_argument(obj):
        fn(obj)  # return value deliberately ignored
        return obj

    return _return_argument
def log_hook(fn: Callable, level=logging.INFO) -> Callable:
    """
    Wrap a pass so that its return value is logged after every call.

    Handy for seeing what a pass reports. Since `inplace_wrapper` replaces a
    pass's return value with the object it modified, apply this hook *inside*
    `inplace_wrapper` if the original return value is what should be logged:

    ```
    def my_pass(d: dict) -> bool:
        changed = False
        if "foo" in d:
            d["foo"] = "bar"
            changed = True
        return changed

    pm = PassManager(passes=[inplace_wrapper(log_hook(my_pass))])
    ```

    Args:
        fn (Callable[Type1, Type2]): pass whose output should be logged.
        level: logging level handed to ``logger.log`` (e.g. logging.INFO).

    Returns:
        wrapped_fn (Callable[Type1, Type2]): calls ``fn``, logs ``fn`` and its
        result through the module logger, then returns the result unchanged.
    """

    @wraps(fn)
    def _logged(obj):
        result = fn(obj)
        logger.log(level, "Ran pass %s\t Return value: %s", fn, result)
        return result

    return _logged
def loop_pass(
    base_pass: Callable,
    n_iter: Optional[int] = None,
    predicate: Optional[Callable] = None,
):
    """
    Wrap ``base_pass`` so that it is applied repeatedly.

    Exactly one of `n_iter` or `predicate` must be specified.

    Args:
        base_pass (Callable[Object, Object]): pass to be applied in a loop;
            each application receives the previous application's output.
        n_iter (int, optional): number of times to apply the pass. Must be
            positive when the returned pass is called.
        predicate (Callable[Object, bool], optional): the pass is re-applied
            while ``predicate(current_output)`` is truthy. It is checked
            before each application, so the pass may run zero times.

    Returns:
        Callable[Object, Object]: the looping pass, carrying ``base_pass``'s
        metadata via ``functools.wraps``.

    Raises:
        AssertionError: if both or neither of `n_iter` / `predicate` are given.
        RuntimeError: when the returned pass is called with a non-positive
            ``n_iter``.
    """
    # Both-given and neither-given are the same failure.
    if (n_iter is None) == (predicate is None):
        raise AssertionError("Exactly one of `n_iter`or `predicate` must be specified.")

    @wraps(base_pass)
    def new_pass(source):
        current = source

        if predicate is not None:
            while predicate(current):
                current = base_pass(current)
            return current

        if n_iter is not None and n_iter > 0:
            for _ in range(n_iter):
                current = base_pass(current)
            return current

        # Only reachable when n_iter was given but is not positive.
        raise RuntimeError(
            f"loop_pass must be given positive int n_iter (given "
            f"{n_iter}) xor predicate (given {predicate})"
        )

    return new_pass
# Pass Schedule Constraints:
#
# Implemented as 'depends on' predicates over pairs of passes. A pass list
# satisfies a constraint iff the predicate returns True for every pair
# (a, b) in which `a` appears earlier in the list than `b`.
- def _validate_pass_schedule_constraint(
- constraint: Callable[[Callable, Callable], bool], passes: list[Callable]
- ):
- for i, a in enumerate(passes):
- for j, b in enumerate(passes[i + 1 :]):
- if constraint(a, b):
- continue
- raise RuntimeError(
- f"pass schedule constraint violated. Expected {a} before {b}"
- f" but found {a} at index {i} and {b} at index{j} in pass"
- f" list."
- )
def this_before_that_pass_constraint(this: Callable, that: Callable):
    """
    Build a 'depends on' predicate requiring `this` to run before `that`.

    The returned ``depends_on(a, b)`` is meant to be called with ``a``
    scheduled earlier than ``b``. It returns False only for the forbidden
    ordering, where ``a`` is `that` and ``b`` is `this`. Otherwise it returns True.
    """

    def depends_on(a: Callable, b: Callable):
        if a != that:
            return True
        # `a` is `that`; the pair is only bad if the later pass is `this`.
        return b != this

    return depends_on
def these_before_those_pass_constraint(these: Callable, those: Callable):
    """
    Build a 'depends on' predicate requiring `these` to run before `those`.

    The two passes being compared are first passed through `inspect.unwrap`.
    As a result, wrappers that set ``__wrapped__`` (e.g. via `functools.wraps`)
    are matched by the pass they wrap. `these` and `those` themselves are not
    unwrapped.

    For example, the following pass list and constraint list would be invalid,
    because `pass_b` is scheduled before `pass_a`:
    ```
    passes = [
        loop_pass(pass_b, 3),
        loop_pass(pass_a, 5),
    ]
    constraints = [these_before_those_pass_constraint(pass_a, pass_b)]
    ```

    Args:
        these (Callable): pass which should occur first
        those (Callable): pass which should occur later

    Returns:
        depends_on (Callable[[Object, Object], bool]): False only when the
        earlier pass unwraps to `those` and the later one unwraps to `these`.
    """

    def depends_on(a: Callable, b: Callable):
        if unwrap(a) != those:
            return True
        # Only unwrap `b` when needed, matching short-circuit evaluation.
        return unwrap(b) != these

    return depends_on
class PassManager:
    """
    Construct a PassManager.

    Collects passes and constraints. This defines the pass schedule, manages
    pass constraints and pass execution.

    Args:
        passes (Optional[List[Callable]]): list of passes. A pass is a
            callable which takes an object and returns the (possibly
            modified) object; passes run in list order, each receiving the
            previous pass's output.
        constraints (Optional[List[Callable]]): list of constraints. A
            constraint is a callable taking two passes ``(a, b)``, where ``a``
            is scheduled before ``b``. It returns True when that ordering is
            allowed and False when it is not. See
            `this_before_that_pass_constraint` for an example.
    """

    passes: list[Callable]
    constraints: list[Callable]
    # Class-level default; set per instance by `validate()` and reset by every
    # mutating method below so constraints are re-checked after changes.
    _validated: bool = False

    def __init__(
        self,
        passes=None,
        constraints=None,
    ):
        # NOTE(review): a non-empty list argument is stored as-is (not
        # copied), so add_pass/add_constraint also mutate the caller's list.
        self.passes = passes or []
        self.constraints = constraints or []

    @classmethod
    def build_from_passlist(cls, passes):
        """Build a manager of the calling class holding ``passes``."""
        # `cls` (not a hard-coded PassManager) so subclasses build themselves.
        pm = cls(passes)
        # TODO(alexbeloi): add constraint management/validation
        return pm

    def add_pass(self, _pass: Callable):
        """Append ``_pass`` to the end of the schedule."""
        self.passes.append(_pass)
        self._validated = False

    def add_constraint(self, constraint):
        """Register an additional schedule constraint."""
        self.constraints.append(constraint)
        self._validated = False

    def remove_pass(self, _passes: list[str]):
        """
        Remove every pass whose ``__name__`` is listed in ``_passes``.

        ``None`` is a no-op. A single string is treated as one pass name; it
        would otherwise be used as a container, and ``in`` would then do
        substring matching.
        """
        if _passes is None:
            return
        if isinstance(_passes, str):
            _passes = [_passes]
        names = set(_passes)
        self.passes = [ps for ps in self.passes if ps.__name__ not in names]
        self._validated = False

    def replace_pass(self, _target, _replacement):
        """Replace every pass whose ``__name__`` equals ``_target``'s with ``_replacement``."""
        self.passes = [
            _replacement if ps.__name__ == _target.__name__ else ps
            for ps in self.passes
        ]
        self._validated = False

    def validate(self):
        """
        Validates that current pass schedule defined by `self.passes` is valid
        according to all constraints in `self.constraints`.

        A successful result is cached until the passes or constraints are
        changed through this class's mutating methods.
        """
        if self._validated:
            return
        for constraint in self.constraints:
            _validate_pass_schedule_constraint(constraint, self.passes)
        self._validated = True

    def __call__(self, source):
        """Validate the schedule, then feed ``source`` through every pass in order."""
        self.validate()
        out = source
        for _pass in self.passes:
            out = _pass(out)
        return out
|