_ir_utils.py 894 B

123456789101112131415161718192021222324252627282930313233
  1. from types import TracebackType
  2. from typing import Optional, Union
  3. import torch
  4. class _InsertPoint:
  5. def __init__(
  6. self,
  7. insert_point_graph: torch._C.Graph,
  8. insert_point: Union[torch._C.Node, torch._C.Block],
  9. ) -> None:
  10. self.insert_point = insert_point
  11. self.g = insert_point_graph
  12. self.guard = None
  13. def __enter__(self) -> None:
  14. self.prev_insert_point = self.g.insertPoint()
  15. self.g.setInsertPoint(self.insert_point)
  16. def __exit__(
  17. self,
  18. exc_type: Optional[type[BaseException]],
  19. exc_val: Optional[BaseException],
  20. exc_tb: Optional[TracebackType],
  21. ) -> None:
  22. self.g.setInsertPoint(self.prev_insert_point)
  23. def insert_point_guard(
  24. self: torch._C.Graph, insert_point: Union[torch._C.Node, torch._C.Block]
  25. ) -> _InsertPoint:
  26. return _InsertPoint(self, insert_point)