memory.py 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. # pyre-strict
  2. r"""This package adds support for device memory management implemented in MTIA."""
  3. from typing import Any
  4. import torch
  5. from . import Device, is_initialized
  6. from ._utils import _get_device_index
  7. def memory_stats(device: Device = None) -> dict[str, Any]:
  8. r"""Return a dictionary of MTIA memory allocator statistics for a given device.
  9. Args:
  10. device (torch.device, str, or int, optional) selected device. Returns
  11. statistics for the current device, given by current_device(),
  12. if device is None (default).
  13. """
  14. if not is_initialized():
  15. return {}
  16. return torch._C._mtia_memoryStats(_get_device_index(device, optional=True))
  17. def max_memory_allocated(device: Device = None) -> int:
  18. r"""Return the maximum memory allocated in bytes for a given device.
  19. Args:
  20. device (torch.device, str, or int, optional) selected device. Returns
  21. statistics for the current device, given by current_device(),
  22. if device is None (default).
  23. """
  24. if not is_initialized():
  25. return 0
  26. return memory_stats(device).get("dram", 0).get("peak_bytes", 0)
  27. def memory_allocated(device: Device = None) -> int:
  28. r"""Return the current MTIA memory occupied by tensors in bytes for a given device.
  29. Args:
  30. device (torch.device or int or str, optional): selected device. Returns
  31. statistic for the current device, given by :func:`~torch.mtia.current_device`,
  32. if :attr:`device` is ``None`` (default).
  33. """
  34. if not is_initialized():
  35. return 0
  36. return memory_stats(device).get("dram", 0).get("allocated_bytes", 0)
  37. def reset_peak_memory_stats(device: Device = None) -> None:
  38. r"""Reset the peak memory stats for a given device.
  39. Args:
  40. device (torch.device, str, or int, optional) selected device. Returns
  41. statistics for the current device, given by current_device(),
  42. if device is None (default).
  43. """
  44. if not is_initialized():
  45. return
  46. torch._C._mtia_resetPeakMemoryStats(_get_device_index(device, optional=True))
  47. __all__ = [
  48. "memory_stats",
  49. "max_memory_allocated",
  50. "memory_allocated",
  51. "reset_peak_memory_stats",
  52. ]