mtia_graph.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. # pylint: disable=useless-parent-delegation
  2. from __future__ import annotations
  3. from typing_extensions import Self
  4. import torch
  5. _POOL_HANDLE = tuple[int, int]
  6. def graph_pool_handle() -> _POOL_HANDLE:
  7. """
  8. Return an opaque token representing the id of a graph memory pool.
  9. """
  10. # pyrefly: ignore [missing-attribute]
  11. return torch._C._mtia_graphPoolHandle()
  12. class MTIAGraph(torch._C._MTIAGraph):
  13. """
  14. Wrapper around a MTIA graph.
  15. """
  16. def __new__(cls, keep_graph: bool = False) -> Self:
  17. return super().__new__(cls, keep_graph)
  18. def capture_begin(self, pool: _POOL_HANDLE) -> None:
  19. """
  20. Begin capturing a MTIA graph.
  21. """
  22. super().capture_begin(pool)
  23. def capture_end(self) -> None:
  24. """
  25. End the capture of a MTIA graph.
  26. """
  27. super().capture_end()
  28. def instantiate(self) -> None:
  29. """
  30. Instantiate the captured MTIA graph.
  31. """
  32. super().instantiate()
  33. def replay(self) -> None:
  34. """
  35. Replay the captured MTIA graph.
  36. """
  37. super().replay()
  38. def reset(self) -> None:
  39. """
  40. Destroy the captured graph and reset the states.
  41. """
  42. super().reset()
  43. def pool(self) -> _POOL_HANDLE:
  44. """
  45. Return an opaque token representing the id of this graph's memory pool
  46. """
  47. return super().pool()
  48. class graph:
  49. default_capture_stream: torch.mtia.Stream | None = None
  50. def __init__(
  51. self,
  52. mtia_graph: MTIAGraph,
  53. pool: _POOL_HANDLE | None = None,
  54. stream: torch.mtia.Stream | None = None,
  55. ):
  56. if self.__class__.default_capture_stream is None:
  57. self.__class__.default_capture_stream = torch.mtia.current_stream()
  58. self.pool: tuple[()] | tuple[_POOL_HANDLE] = () if pool is None else (pool,)
  59. self.capture_stream = (
  60. stream if stream is not None else self.__class__.default_capture_stream
  61. )
  62. if self.capture_stream is None:
  63. raise AssertionError("capture_stream must not be None")
  64. self.stream_ctx = torch.mtia.stream(self.capture_stream)
  65. self.mtia_graph = mtia_graph
  66. def __enter__(self) -> None:
  67. torch.mtia.synchronize()
  68. torch.mtia.empty_cache()
  69. self.stream_ctx.__enter__()
  70. pool_arg = self.pool[0] if self.pool else (0, 0)
  71. self.mtia_graph.capture_begin(pool_arg)
  72. def __exit__(self, *args: object) -> None:
  73. self.mtia_graph.capture_end()
  74. self.stream_ctx.__exit__(*args)
  75. __all__ = [
  76. "MTIAGraph",
  77. "graph",
  78. "graph_pool_handle",
  79. ]