__init__.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. # mypy: allow-untyped-defs
  2. r"""
  3. This package enables an interface for accessing MPS (Metal Performance Shaders) backend in Python.
  4. Metal is Apple's API for programming Metal GPUs (graphics processing units). Using MPS means that increased
  5. performance can be achieved by running work on the Metal GPU(s).
  6. See https://developer.apple.com/documentation/metalperformanceshaders for more details.
  7. """
  8. import torch
  9. from torch import Tensor
# Fork-safety probe: use the ``_mps_is_in_bad_fork`` binding from ``torch._C``
# when it is present, otherwise fall back to a callable that always returns False.
_is_in_bad_fork = getattr(torch._C, "_mps_is_in_bad_fork", lambda: False)
# Cache for the default MPS generator. It starts out as None and is filled in
# on first use by ``_get_default_mps_generator()``.
_default_mps_generator: torch._C.Generator = None  # type: ignore[assignment]
  12. # local helper function (not public or exported)
  13. def _get_default_mps_generator() -> torch._C.Generator:
  14. global _default_mps_generator
  15. if _default_mps_generator is None:
  16. _default_mps_generator = torch._C._mps_get_default_generator()
  17. return _default_mps_generator
  18. def device_count() -> int:
  19. r"""Returns the number of available MPS devices."""
  20. return int(torch._C._has_mps and torch._C._mps_is_available())
  21. def synchronize() -> None:
  22. r"""Waits for all kernels in all streams on a MPS device to complete."""
  23. return torch._C._mps_deviceSynchronize()
  24. def get_rng_state(device: int | str | torch.device = "mps") -> Tensor:
  25. r"""Returns the random number generator state as a ByteTensor.
  26. Args:
  27. device (torch.device or int, optional): The device to return the RNG state of.
  28. Default: ``'mps'`` (i.e., ``torch.device('mps')``, the current MPS device).
  29. """
  30. return _get_default_mps_generator().get_state()
  31. def set_rng_state(new_state: Tensor, device: int | str | torch.device = "mps") -> None:
  32. r"""Sets the random number generator state.
  33. Args:
  34. new_state (torch.ByteTensor): The desired state
  35. device (torch.device or int, optional): The device to set the RNG state.
  36. Default: ``'mps'`` (i.e., ``torch.device('mps')``, the current MPS device).
  37. """
  38. new_state_copy = new_state.clone(memory_format=torch.contiguous_format)
  39. _get_default_mps_generator().set_state(new_state_copy)
  40. def manual_seed(seed: int) -> None:
  41. r"""Sets the seed for generating random numbers.
  42. Args:
  43. seed (int): The desired seed.
  44. """
  45. # the torch.mps.manual_seed() can be called from the global
  46. # torch.manual_seed() in torch/random.py. So we need to make
  47. # sure mps is available (otherwise we just return without
  48. # erroring out)
  49. if not torch._C._has_mps:
  50. return
  51. seed = int(seed)
  52. _get_default_mps_generator().manual_seed(seed)
  53. def seed() -> None:
  54. r"""Sets the seed for generating random numbers to a random number."""
  55. _get_default_mps_generator().seed()
  56. def empty_cache() -> None:
  57. r"""Releases all unoccupied cached memory currently held by the caching
  58. allocator so that those can be used in other GPU applications.
  59. """
  60. torch._C._mps_emptyCache()
  61. def set_per_process_memory_fraction(fraction) -> None:
  62. r"""Set memory fraction for limiting process's memory allocation on MPS device.
  63. The allowed value equals the fraction multiplied by recommended maximum device memory
  64. (obtained from Metal API device.recommendedMaxWorkingSetSize).
  65. If trying to allocate more than the allowed value in a process, it will raise an out of
  66. memory error in allocator.
  67. Args:
  68. fraction(float): Range: 0~2. Allowed memory equals total_memory * fraction.
  69. .. note::
  70. Passing 0 to fraction means unlimited allocations
  71. (may cause system failure if out of memory).
  72. Passing fraction greater than 1.0 allows limits beyond the value
  73. returned from device.recommendedMaxWorkingSetSize.
  74. """
  75. if not isinstance(fraction, float):
  76. raise TypeError("Invalid type for fraction argument, must be `float`")
  77. if fraction < 0 or fraction > 2:
  78. raise ValueError(f"Invalid fraction value: {fraction}. Allowed range: 0~2")
  79. torch._C._mps_setMemoryFraction(fraction)
  80. def current_allocated_memory() -> int:
  81. r"""Returns the current GPU memory occupied by tensors in bytes.
  82. .. note::
  83. The returned size does not include cached allocations in
  84. memory pools of MPSAllocator.
  85. """
  86. return torch._C._mps_currentAllocatedMemory()
  87. def driver_allocated_memory() -> int:
  88. r"""Returns total GPU memory allocated by Metal driver for the process in bytes.
  89. .. note::
  90. The returned size includes cached allocations in MPSAllocator pools
  91. as well as allocations from MPS/MPSGraph frameworks.
  92. """
  93. return torch._C._mps_driverAllocatedMemory()
  94. def recommended_max_memory() -> int:
  95. r"""Returns recommended max Working set size for GPU memory in bytes.
  96. .. note::
  97. Recommended max working set size for Metal.
  98. returned from device.recommendedMaxWorkingSetSize.
  99. """
  100. return torch._C._mps_recommendedMaxMemory()
  101. def compile_shader(source: str):
  102. r"""Compiles compute shader from source and allows one to invoke kernels
  103. defined there from the comfort of Python runtime
  104. Example::
  105. >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_MPS)
  106. >>> lib = torch.mps.compile_shader(
  107. ... "kernel void full(device float* out, constant float& val, uint idx [[thread_position_in_grid]]) { out[idx] = val; }"
  108. ... )
  109. >>> x = torch.zeros(16, device="mps")
  110. >>> lib.full(x, 3.14)
  111. """
  112. from pathlib import Path
  113. from torch.utils._cpp_embed_headers import _embed_headers
  114. if not hasattr(torch._C, "_mps_compileShader"):
  115. raise RuntimeError("MPS is not available")
  116. source = _embed_headers(
  117. [l + "\n" for l in source.split("\n")],
  118. [Path(__file__).parent.parent / "include"],
  119. set(),
  120. )
  121. return torch._C._mps_compileShader(source)
  122. def is_available() -> bool:
  123. return device_count() > 0
  124. from . import profiler
  125. from .event import Event
# Public API of this package: the names bound by ``from torch.mps import *``.
# Underscore-prefixed helpers defined above are deliberately left out.
__all__ = [
    "compile_shader",
    "device_count",
    "get_rng_state",
    "manual_seed",
    "seed",
    "set_rng_state",
    "synchronize",
    "empty_cache",
    "set_per_process_memory_fraction",
    "current_allocated_memory",
    "driver_allocated_memory",
    "Event",
    "profiler",
    "recommended_max_memory",
    "is_available",
]