| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104 |
- # pylint: disable=useless-parent-delegation
- from __future__ import annotations
- from typing_extensions import Self
- import torch
# Token identifying a graph memory pool: a pair of ints that callers treat as
# opaque and only pass back into the MTIA graph APIs.
_POOL_HANDLE = tuple[int, int]


def graph_pool_handle() -> _POOL_HANDLE:
    """Obtain an opaque token naming an MTIA graph memory pool.

    The returned pair can be supplied as ``pool`` to ``MTIAGraph.capture_begin``
    or to the :class:`graph` context manager.
    """
    # pyrefly: ignore [missing-attribute]
    handle = torch._C._mtia_graphPoolHandle()
    return handle
- class MTIAGraph(torch._C._MTIAGraph):
- """
- Wrapper around a MTIA graph.
- """
- def __new__(cls, keep_graph: bool = False) -> Self:
- return super().__new__(cls, keep_graph)
- def capture_begin(self, pool: _POOL_HANDLE) -> None:
- """
- Begin capturing a MTIA graph.
- """
- super().capture_begin(pool)
- def capture_end(self) -> None:
- """
- End the capture of a MTIA graph.
- """
- super().capture_end()
- def instantiate(self) -> None:
- """
- Instantiate the captured MTIA graph.
- """
- super().instantiate()
- def replay(self) -> None:
- """
- Replay the captured MTIA graph.
- """
- super().replay()
- def reset(self) -> None:
- """
- Destroy the captured graph and reset the states.
- """
- super().reset()
- def pool(self) -> _POOL_HANDLE:
- """
- Return an opaque token representing the id of this graph's memory pool
- """
- return super().pool()
class graph:
    """Context manager that captures MTIA work into an ``MTIAGraph``.

    On entry it synchronizes the device, empties the allocator cache, makes
    the capture stream current and begins capture. On exit it ends capture and
    restores the previously current stream, even when ending the capture fails.

    Args:
        mtia_graph: Graph object to capture into.
        pool: Optional memory-pool token (for example from
            ``graph_pool_handle()`` or ``MTIAGraph.pool()``). When omitted,
            ``(0, 0)`` is passed to ``capture_begin``.
        stream: Stream to capture on. When omitted, the class-wide
            ``default_capture_stream`` is used.

    Raises:
        AssertionError: if no capture stream could be determined.
    """

    # Shared by every instance; filled in lazily by the first __init__.
    default_capture_stream: torch.mtia.Stream | None = None

    def __init__(
        self,
        mtia_graph: MTIAGraph,
        pool: _POOL_HANDLE | None = None,
        stream: torch.mtia.Stream | None = None,
    ):
        # The stream that is current at the *first* construction is cached on
        # the class and reused by every later instance that omits ``stream``.
        # NOTE(review): this lazy init is unsynchronized, and the cached stream
        # is whatever was current at that moment — confirm both are intended.
        if self.__class__.default_capture_stream is None:
            self.__class__.default_capture_stream = torch.mtia.current_stream()

        # Kept as a 0- or 1-tuple so "no pool supplied" stays distinguishable
        # from any real handle.
        self.pool: tuple[()] | tuple[_POOL_HANDLE] = () if pool is None else (pool,)
        self.capture_stream = (
            stream if stream is not None else self.__class__.default_capture_stream
        )
        if self.capture_stream is None:
            raise AssertionError("capture_stream must not be None")
        self.stream_ctx = torch.mtia.stream(self.capture_stream)
        self.mtia_graph = mtia_graph

    def __enter__(self) -> None:
        # Drain outstanding work and release cached memory before capturing.
        torch.mtia.synchronize()
        torch.mtia.empty_cache()

        self.stream_ctx.__enter__()
        # (0, 0) is what gets passed when the caller supplied no pool.
        pool_arg = self.pool[0] if self.pool else (0, 0)
        try:
            self.mtia_graph.capture_begin(pool_arg)
        except BaseException as exc:
            # ``with`` does not call __exit__ when __enter__ raises, so undo the
            # stream switch here before propagating the error.
            self.stream_ctx.__exit__(type(exc), exc, exc.__traceback__)
            raise

    def __exit__(self, *args: object) -> None:
        # Always restore the previous stream, even if ending capture raises.
        try:
            self.mtia_graph.capture_end()
        finally:
            self.stream_ctx.__exit__(*args)
        # Returning None lets exceptions from the ``with`` body, capture_end,
        # or the stream context propagate to the caller.
# Public API of this module.
__all__ = ["MTIAGraph", "graph", "graph_pool_handle"]
|