memory.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237
  1. from collections import OrderedDict
  2. from typing import Any
  3. import torch
  4. from ._utils import _device_t, _get_device_index
  5. __all__ = [
  6. "empty_cache",
  7. "get_memory_info",
  8. "max_memory_allocated",
  9. "max_memory_reserved",
  10. "memory_allocated",
  11. "memory_reserved",
  12. "memory_stats",
  13. "reset_accumulated_memory_stats",
  14. "reset_peak_memory_stats",
  15. ]
  16. def empty_cache() -> None:
  17. r"""Release all unoccupied cached memory currently held by the caching
  18. allocator so that those can be used in other application.
  19. .. note:: This function is a no-op if the memory allocator for the current
  20. :ref:`accelerator <accelerators>` has not been initialized.
  21. """
  22. if not torch._C._accelerator_isAllocatorInitialized():
  23. return
  24. torch._C._accelerator_emptyCache()
  25. def memory_stats(device_index: _device_t = None, /) -> OrderedDict[str, Any]:
  26. r"""Return a dictionary of accelerator device memory allocator statistics for a given device index.
  27. The return value of this function is a dictionary of statistics, each of
  28. which is a non-negative integer.
  29. Core statistics:
  30. - ``"allocated.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
  31. number of allocation requests received by the memory allocator.
  32. - ``"allocated_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
  33. amount of allocated memory.
  34. - ``"segment.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
  35. number of reserved segments from device memory allocation.
  36. - ``"reserved_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
  37. amount of reserved memory.
  38. - ``"active.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
  39. number of active memory blocks.
  40. - ``"active_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
  41. amount of active memory.
  42. - ``"inactive_split.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
  43. number of inactive, non-releasable memory blocks.
  44. - ``"inactive_split_bytes.{all,large_pool,small_pool}.{current,peak,allocated,freed}"``:
  45. amount of inactive, non-releasable memory.
  46. For these core statistics, values are broken down as follows.
  47. Pool type:
  48. - ``all``: combined statistics across all memory pools.
  49. - ``large_pool``: statistics for the large allocation pool
  50. (as of June 2025, for size >= 1MB allocations).
  51. - ``small_pool``: statistics for the small allocation pool
  52. (as of June 2025, for size < 1MB allocations).
  53. Metric type:
  54. - ``current``: current value of this metric.
  55. - ``peak``: maximum value of this metric.
  56. - ``allocated``: historical total increase in this metric.
  57. - ``freed``: historical total decrease in this metric.
  58. In addition to the core statistics, we also provide some simple event
  59. counters:
  60. - ``"num_alloc_retries"``: number of failed device memory allocation calls that
  61. result in a cache flush and retry.
  62. - ``"num_ooms"``: number of out-of-memory errors thrown.
  63. - ``"num_sync_all_streams"``: number of ``synchronize_and_free_events`` calls.
  64. - ``"num_device_alloc"``: number of device memory allocation calls.
  65. - ``"num_device_free"``: number of device memory free calls.
  66. Args:
  67. device_index (:class:`torch.device`, str, int, optional): the index of the device to target.
  68. If not given, use :func:`torch.accelerator.current_device_index` by default.
  69. If a :class:`torch.device` or str is provided, its type must match the current
  70. :ref:`accelerator<accelerators>` device type.
  71. Returns:
  72. OrderedDict[str, Any]: an ordered dictionary mapping statistic names to their values.
  73. """
  74. if not torch._C._accelerator_isAllocatorInitialized():
  75. return OrderedDict()
  76. device_index = _get_device_index(device_index, optional=True)
  77. stats = torch._C._accelerator_getDeviceStats(device_index)
  78. flat_stats = []
  79. def flatten(prefix: str, value: Any) -> None:
  80. if isinstance(value, dict):
  81. for k, v in value.items():
  82. nested_prefix = f"{prefix}.{k}" if prefix else k
  83. flatten(nested_prefix, v)
  84. else:
  85. flat_stats.append((prefix, value))
  86. flatten("", stats)
  87. flat_stats.sort()
  88. # pyrefly: ignore [no-matching-overload]
  89. return OrderedDict(flat_stats)
  90. def memory_allocated(device_index: _device_t = None, /) -> int:
  91. r"""Return the current :ref:`accelerator<accelerators>` device memory occupied by tensors
  92. in bytes for a given device index.
  93. Args:
  94. device_index (:class:`torch.device`, str, int, optional): the index of the device to target.
  95. If not given, use :func:`torch.accelerator.current_device_index` by default.
  96. If a :class:`torch.device` or str is provided, its type must match the current
  97. :ref:`accelerator<accelerators>` device type.
  98. Returns:
  99. int: the current memory occupied by live tensors (in bytes) within the current process.
  100. """
  101. return memory_stats(device_index).get("allocated_bytes.all.current", 0)
  102. def max_memory_allocated(device_index: _device_t = None, /) -> int:
  103. r"""Return the current :ref:`accelerator<accelerators>` maximum device memory occupied by tensors
  104. in bytes for a given device index.
  105. By default, this returns the peak allocated memory since the beginning of
  106. this program. :func:`~torch.accelerator.reset_peak_memory_stats` can be used to
  107. reset the starting point in tracking this metric.
  108. Args:
  109. device_index (:class:`torch.device`, str, int, optional): the index of the device to target.
  110. If not given, use :func:`torch.accelerator.current_device_index` by default.
  111. If a :class:`torch.device` or str is provided, its type must match the current
  112. :ref:`accelerator<accelerators>` device type.
  113. Returns:
  114. int: the peak memory occupied by live tensors (in bytes) within the current process.
  115. """
  116. return memory_stats(device_index).get("allocated_bytes.all.peak", 0)
  117. def memory_reserved(device_index: _device_t = None, /) -> int:
  118. r"""Return the current :ref:`accelerator<accelerators>` device memory managed by the caching allocator
  119. in bytes for a given device index.
  120. Args:
  121. device_index (:class:`torch.device`, str, int, optional): the index of the device to target.
  122. If not given, use :func:`torch.accelerator.current_device_index` by default.
  123. If a :class:`torch.device` or str is provided, its type must match the current
  124. :ref:`accelerator<accelerators>` device type.
  125. Returns:
  126. int: the current memory reserved by PyTorch (in bytes) within the current process.
  127. """
  128. return memory_stats(device_index).get("reserved_bytes.all.current", 0)
  129. def max_memory_reserved(device_index: _device_t = None, /) -> int:
  130. r"""Return the current :ref:`accelerator<accelerators>` maximum device memory managed by the caching allocator
  131. in bytes for a given device index.
  132. By default, this returns the peak cached memory since the beginning of this
  133. program. :func:`~torch.accelerator.reset_peak_memory_stats` can be used to reset
  134. the starting point in tracking this metric.
  135. Args:
  136. device_index (:class:`torch.device`, str, int, optional): the index of the device to target.
  137. If not given, use :func:`torch.accelerator.current_device_index` by default.
  138. If a :class:`torch.device` or str is provided, its type must match the current
  139. :ref:`accelerator<accelerators>` device type.
  140. Returns:
  141. int: the peak memory reserved by PyTorch (in bytes) within the current process.
  142. """
  143. return memory_stats(device_index).get("reserved_bytes.all.peak", 0)
  144. def reset_accumulated_memory_stats(device_index: _device_t = None, /) -> None:
  145. r"""Reset the "accumulated" (historical) stats tracked by the current :ref:`accelerator<accelerators>`
  146. memory allocator for a given device index.
  147. Args:
  148. device_index (:class:`torch.device`, str, int, optional): the index of the device to target.
  149. If not given, use :func:`torch.accelerator.current_device_index` by default.
  150. If a :class:`torch.device` or str is provided, its type must match the current
  151. :ref:`accelerator<accelerators>` device type.
  152. .. note:: This function is a no-op if the memory allocator for the current
  153. :ref:`accelerator <accelerators>` has not been initialized.
  154. """
  155. device_index = _get_device_index(device_index, optional=True)
  156. return torch._C._accelerator_resetAccumulatedStats(device_index)
  157. def reset_peak_memory_stats(device_index: _device_t = None, /) -> None:
  158. r"""Reset the "peak" stats tracked by the current :ref:`accelerator<accelerators>`
  159. memory allocator for a given device index.
  160. Args:
  161. device_index (:class:`torch.device`, str, int, optional): the index of the device to target.
  162. If not given, use :func:`torch.accelerator.current_device_index` by default.
  163. If a :class:`torch.device` or str is provided, its type must match the current
  164. :ref:`accelerator<accelerators>` device type.
  165. .. note:: This function is a no-op if the memory allocator for the current
  166. :ref:`accelerator <accelerators>` has not been initialized.
  167. """
  168. device_index = _get_device_index(device_index, optional=True)
  169. return torch._C._accelerator_resetPeakStats(device_index)
  170. def get_memory_info(device_index: _device_t = None, /) -> tuple[int, int]:
  171. r"""Return the current device memory information for a given device index.
  172. Args:
  173. device_index (:class:`torch.device`, str, int, optional): the index of the device to target.
  174. If not given, use :func:`torch.accelerator.current_device_index` by default.
  175. If a :class:`torch.device` or str is provided, its type must match the current
  176. :ref:`accelerator<accelerators>` device type.
  177. Returns:
  178. tuple[int, int]: a tuple of two integers (free_memory, total_memory) in bytes.
  179. The first value is the free memory on the device (available across all processes and applications),
  180. The second value is the device's total hardware memory capacity.
  181. """
  182. device_index = _get_device_index(device_index, optional=True)
  183. # pyrefly: ignore [missing-attribute]
  184. return torch._C._accelerator_getMemoryInfo(device_index)