__init__.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252
  1. # mypy: allow-untyped-defs
  2. r"""
  3. This package implements abstractions found in ``torch.cuda``
  4. to facilitate writing device-agnostic code.
  5. """
  6. from collections.abc import Mapping
  7. from contextlib import AbstractContextManager
  8. from functools import lru_cache
  9. from types import MappingProxyType
  10. from typing import Any
  11. import torch
  12. from .. import device as _device
  13. from . import amp
# Public names of this package: the only ones bound by a star-import.
# Underscore-prefixed helpers defined below are deliberately not listed.
__all__ = [
    "is_available",
    "is_initialized",
    "synchronize",
    "current_device",
    "current_stream",
    "stream",
    "set_device",
    "device_count",
    "Stream",
    "StreamContext",
    "Event",
    "get_capabilities",
]
  28. @lru_cache(None)
  29. def get_capabilities() -> Mapping[str, Any]:
  30. """
  31. Returns an immutable mapping of CPU capabilities detected at runtime.
  32. This function queries the CPU for supported instruction sets and features
  33. using cpuinfo. The result is cached after the first call for efficiency.
  34. The returned mapping contains architecture-specific capabilities:
  35. For x86/x86_64:
  36. - SSE family: sse, sse2, sse3, ssse3, sse4_1, sse4_2, sse4a
  37. - AVX family: avx, avx2, avx_vnni
  38. - AVX-512 family: avx512_f, avx512_cd, avx512_dq, avx512_bw, avx512_vl,
  39. avx512_ifma, avx512_vbmi, avx512_vbmi2, avx512_bitalg, avx512_vpopcntdq,
  40. avx512_vnni, avx512_bf16, avx512_fp16, avx512_vp2intersect,
  41. avx512_4vnniw, avx512_4fmaps
  42. - AVX10 family: avx10_1, avx10_2
  43. - AVX-VNNI-INT: avx_vnni_int8, avx_vnni_int16, avx_ne_convert
  44. - AMX: amx_bf16, amx_tile, amx_int8, amx_fp16
  45. - FMA: fma3, fma4
  46. - Other: f16c, bmi, bmi2, popcnt, lzcnt, aes, sha, clflush, clflushopt, clwb
  47. For ARM64:
  48. - SIMD: neon, fp16_arith, bf16, i8mm, dot
  49. - SVE: sve, sve2, sve_bf16, sve_max_length (when supported)
  50. - SME: sme, sme2, sme_max_length (when supported)
  51. - Other: atomics, fhm, rdm, crc32, aes, sha1, sha2, pmull
  52. Common to all architectures:
  53. - architecture: string identifying the CPU architecture
  54. Returns:
  55. MappingProxyType: An immutable mapping where keys are capability names
  56. (e.g., 'avx2', 'sve') and values are booleans indicating
  57. support, or integers for properties like vector lengths.
  58. Example:
  59. >>> caps = torch.cpu.get_capabilities()
  60. >>> if caps.get("avx2", False):
  61. ... print("AVX2 is supported")
  62. >>> print(f"Architecture: {caps['architecture']}")
  63. """
  64. return MappingProxyType(torch._C._cpu._get_cpu_capability())
  65. def _is_avx2_supported() -> bool:
  66. r"""Returns a bool indicating if CPU supports AVX2."""
  67. return get_capabilities().get("avx2", False)
  68. def _is_avx512_supported() -> bool:
  69. r"""Returns a bool indicating if CPU supports AVX512."""
  70. return get_capabilities().get("avx512_f", False)
  71. def _is_avx512_bf16_supported() -> bool:
  72. r"""Returns a bool indicating if CPU supports AVX512_BF16."""
  73. return get_capabilities().get("avx512_bf16", False)
  74. def _is_vnni_supported() -> bool:
  75. r"""Returns a bool indicating if CPU supports VNNI."""
  76. # Note: Currently, it only checks avx512_vnni, will add the support of avx2_vnni later.
  77. return get_capabilities().get("avx512_vnni", False)
  78. def _is_amx_tile_supported() -> bool:
  79. r"""Returns a bool indicating if CPU supports AMX_TILE."""
  80. return get_capabilities().get("amx_tile", False)
  81. def _is_amx_fp16_supported() -> bool:
  82. r"""Returns a bool indicating if CPU supports AMX FP16."""
  83. return get_capabilities().get("amx_fp16", False)
  84. def _init_amx() -> bool:
  85. r"""Initializes AMX instructions."""
  86. return torch._C._cpu._init_amx()
  87. def is_available() -> bool:
  88. r"""Returns a bool indicating if CPU is currently available.
  89. N.B. This function only exists to facilitate device-agnostic code
  90. """
  91. return True
  92. def synchronize(device: torch.types.Device = None) -> None:
  93. r"""Waits for all kernels in all streams on the CPU device to complete.
  94. Args:
  95. device (torch.device or int, optional): ignored, there's only one CPU device.
  96. N.B. This function only exists to facilitate device-agnostic code.
  97. """
  98. class Stream:
  99. """
  100. N.B. This class only exists to facilitate device-agnostic code
  101. """
  102. def __init__(self, priority: int = -1) -> None:
  103. pass
  104. def wait_stream(self, stream) -> None:
  105. pass
  106. def record_event(self) -> None:
  107. pass
  108. def wait_event(self, event) -> None:
  109. pass
  110. class Event:
  111. def query(self) -> bool:
  112. return True
  113. def record(self, stream=None) -> None:
  114. pass
  115. def synchronize(self) -> None:
  116. pass
  117. def wait(self, stream=None) -> None:
  118. pass
# Single module-level Stream instance that serves as the default CPU stream.
_default_cpu_stream = Stream()
# Stream handed out by current_stream(); StreamContext rebinds this global on
# enter/exit. It starts out as the default stream.
_current_stream = _default_cpu_stream
  121. def current_stream(device: torch.types.Device = None) -> Stream:
  122. r"""Returns the currently selected :class:`Stream` for a given device.
  123. Args:
  124. device (torch.device or int, optional): Ignored.
  125. N.B. This function only exists to facilitate device-agnostic code
  126. """
  127. return _current_stream
  128. class StreamContext(AbstractContextManager):
  129. r"""Context-manager that selects a given stream.
  130. N.B. This class only exists to facilitate device-agnostic code
  131. """
  132. cur_stream: Stream | None
  133. def __init__(self, stream):
  134. self.stream = stream
  135. self.prev_stream = _default_cpu_stream
  136. def __enter__(self):
  137. cur_stream = self.stream
  138. if cur_stream is None:
  139. return
  140. global _current_stream
  141. self.prev_stream = _current_stream
  142. _current_stream = cur_stream
  143. def __exit__(self, type: Any, value: Any, traceback: Any) -> None:
  144. cur_stream = self.stream
  145. if cur_stream is None:
  146. return
  147. global _current_stream
  148. _current_stream = self.prev_stream
  149. def stream(stream: Stream) -> AbstractContextManager:
  150. r"""Wrapper around the Context-manager StreamContext that
  151. selects a given stream.
  152. N.B. This function only exists to facilitate device-agnostic code
  153. """
  154. return StreamContext(stream)
  155. def device_count() -> int:
  156. r"""Returns number of CPU devices (not cores). Always 1.
  157. N.B. This function only exists to facilitate device-agnostic code
  158. """
  159. return 1
  160. def set_device(device: torch.types.Device) -> None:
  161. r"""Sets the current device, in CPU we do nothing.
  162. N.B. This function only exists to facilitate device-agnostic code
  163. """
  164. def current_device() -> str:
  165. r"""Returns current device for cpu. Always 'cpu'.
  166. N.B. This function only exists to facilitate device-agnostic code
  167. """
  168. return "cpu"
  169. def is_initialized() -> bool:
  170. r"""Returns True if the CPU is initialized. Always True.
  171. N.B. This function only exists to facilitate device-agnostic code
  172. """
  173. return True