device_interface.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612
"""
Device abstraction layer for TorchDynamo and Inductor backends.

This module provides a unified interface for different hardware backends (CUDA, XPU,
CPU, MPS, MTIA) through a common device interface. Key components include:

- DeviceInterface: Base class defining the common API for all device types
- Device-specific implementations: CudaInterface, XpuInterface, CpuInterface, MpsInterface, MtiaInterface
- Device registration system for managing available backends
- Worker APIs for multi-processing scenarios
- Stream and event management across different devices
- Device property caching for worker processes

The abstraction layer enables device-agnostic code in TorchDynamo while allowing
specialized implementations for each hardware backend's unique features.
"""
  14. import inspect
  15. import time
  16. from collections import namedtuple
  17. from collections.abc import Callable, Iterable
  18. from dataclasses import dataclass
  19. from typing import Any, Literal, Optional, Union
  20. import torch
  21. get_cuda_stream: Optional[Callable[[int], int]]
  22. if torch.cuda._is_compiled():
  23. from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
  24. else:
  25. get_cuda_stream = None
  26. # Recording the device properties in the main process but used in worker process.
  27. caching_worker_device_properties: dict[str, Any] = {}
  28. caching_worker_current_devices: dict[str, int] = {}
  29. class DeviceInterface:
  30. """
  31. This is a simple device runtime interface for Inductor. It enables custom
  32. backends to be integrated with Inductor in a device-agnostic semantic.
  33. """
  34. class device:
  35. def __new__(cls, device: torch.types.Device) -> Any:
  36. raise NotImplementedError
  37. class Event:
  38. def __new__(cls, *args: Any, **kwargs: Any) -> Any:
  39. raise NotImplementedError(
  40. "Event should be inherited from torch.Event, otherwise, it couldn't be captured by dynamo."
  41. )
  42. class Stream:
  43. def __new__(cls, *args: Any, **kwargs: Any) -> Any:
  44. raise NotImplementedError(
  45. "Stream should be inherited from torch.Stream, otherwise, it couldn't be captured by dynamo."
  46. )
  47. class Worker:
  48. """
  49. Worker API to query device properties that will work in multi processing
  50. workers that cannot use the GPU APIs (due to processing fork() and
  51. initialization time issues). Properties are recorded in the main process
  52. before we fork the workers.
  53. """
  54. @staticmethod
  55. def set_device(device: int) -> None:
  56. raise NotImplementedError
  57. @staticmethod
  58. def current_device() -> int:
  59. raise NotImplementedError
  60. @staticmethod
  61. def get_device_properties(device: torch.types.Device = None) -> Any:
  62. raise NotImplementedError
  63. @staticmethod
  64. def current_device() -> int:
  65. raise NotImplementedError
  66. @staticmethod
  67. def set_device(device: torch.types.Device) -> None:
  68. raise NotImplementedError
  69. @staticmethod
  70. def maybe_exchange_device(device: int) -> int:
  71. raise NotImplementedError
  72. @staticmethod
  73. def exchange_device(device: int) -> int:
  74. raise NotImplementedError
  75. @staticmethod
  76. def device_count() -> int:
  77. raise NotImplementedError
  78. @staticmethod
  79. def is_available() -> bool:
  80. raise NotImplementedError
  81. @staticmethod
  82. def stream(stream: torch.Stream) -> Any:
  83. raise NotImplementedError
  84. @staticmethod
  85. def current_stream() -> torch.Stream:
  86. raise NotImplementedError
  87. @staticmethod
  88. def set_stream(stream: torch.Stream) -> None:
  89. raise NotImplementedError
  90. @staticmethod
  91. def _set_stream_by_id(stream_id: int, device_index: int, device_type: int) -> None:
  92. raise NotImplementedError
  93. @staticmethod
  94. def get_raw_stream(device_idx: int) -> int:
  95. raise NotImplementedError
  96. @staticmethod
  97. def synchronize(device: torch.types.Device = None) -> None:
  98. raise NotImplementedError
  99. @classmethod
  100. def get_device_properties(cls, device: torch.types.Device = None) -> Any:
  101. return cls.Worker.get_device_properties(device)
  102. @staticmethod
  103. def get_compute_capability(device: torch.types.Device = None) -> Any:
  104. raise NotImplementedError
  105. @staticmethod
  106. def is_bf16_supported(including_emulation: bool = False) -> bool:
  107. raise NotImplementedError
  108. @classmethod
  109. def is_dtype_supported(
  110. cls, dtype: torch.dtype, including_emulation: bool = False
  111. ) -> bool:
  112. return dtype != torch.bfloat16 or cls.is_bf16_supported(including_emulation)
  113. @staticmethod
  114. def memory_allocated(device: torch.types.Device = None) -> int:
  115. raise NotImplementedError
  116. @staticmethod
  117. def is_triton_capable(device: torch.types.Device = None) -> bool:
  118. """
  119. Returns True if the device has Triton support, False otherwise, even if
  120. the appropriate Triton backend is not available.
  121. """
  122. return False
  123. @classmethod
  124. def raise_if_triton_unavailable(cls, device: torch.types.Device = None) -> None:
  125. """
  126. Raises a `RuntimeError` with the appropriate human-readable instructions
  127. to resolve the issue if Triton is not available for the given device, or
  128. the default device if `device` is `None`.
  129. The caller should ensure the presence of the 'triton' package before
  130. calling this method.
  131. """
  132. if not cls.is_triton_capable():
  133. raise RuntimeError("This device is not capable of supporting Triton")
  134. class DeviceGuard:
  135. """
  136. This class provides a context manager for device switching. This is a stripped
  137. down version of torch.{device_name}.device.
  138. The context manager changes the current device to the given device index
  139. on entering the context and restores the original device on exiting.
  140. The device is switched using the provided device interface.
  141. """
  142. def __init__(
  143. self, device_interface: type[DeviceInterface], index: Optional[int]
  144. ) -> None:
  145. self.device_interface = device_interface
  146. self.idx = index
  147. self.prev_idx = -1
  148. def __enter__(self) -> None:
  149. if self.idx is not None:
  150. self.prev_idx = self.device_interface.exchange_device(self.idx)
  151. def __exit__(self, type: Any, value: Any, traceback: Any) -> Literal[False]:
  152. if self.idx is not None:
  153. self.idx = self.device_interface.maybe_exchange_device(self.prev_idx)
  154. return False
  155. class CudaInterface(DeviceInterface):
  156. device = torch.cuda.device # type: ignore[assignment]
  157. # register Event and Stream class into the backend interface
  158. # make sure Event and Stream are implemented and inherited from the torch.Event and torch.Stream
  159. Event = torch.cuda.Event # type: ignore[assignment]
  160. Stream = torch.cuda.Stream # type: ignore[assignment]
  161. # pyrefly: ignore [bad-override]
  162. class Worker:
  163. @staticmethod
  164. def set_device(device: int) -> None:
  165. caching_worker_current_devices["cuda"] = device
  166. @staticmethod
  167. def current_device() -> int:
  168. if "cuda" in caching_worker_current_devices:
  169. return caching_worker_current_devices["cuda"]
  170. return torch.cuda.current_device()
  171. @staticmethod
  172. def get_device_properties(device: torch.types.Device = None) -> Any:
  173. if device is not None:
  174. if isinstance(device, str):
  175. device = torch.device(device)
  176. assert device.type == "cuda"
  177. if isinstance(device, torch.device):
  178. device = device.index
  179. if device is None:
  180. device = CudaInterface.Worker.current_device()
  181. if "cuda" not in caching_worker_device_properties:
  182. device_prop = [
  183. torch.cuda.get_device_properties(i)
  184. for i in range(torch.cuda.device_count())
  185. ]
  186. caching_worker_device_properties["cuda"] = device_prop
  187. return caching_worker_device_properties["cuda"][device]
  188. current_device = staticmethod(torch.cuda.current_device)
  189. set_device = staticmethod(torch.cuda.set_device)
  190. device_count = staticmethod(torch.cuda.device_count)
  191. stream = staticmethod(torch.cuda.stream) # type: ignore[assignment]
  192. # pyrefly: ignore [bad-override]
  193. current_stream = staticmethod(torch.cuda.current_stream)
  194. set_stream = staticmethod(torch.cuda.set_stream) # type: ignore[assignment]
  195. _set_stream_by_id = staticmethod(torch.cuda._set_stream_by_id) # type: ignore[assignment]
  196. synchronize = staticmethod(torch.cuda.synchronize)
  197. get_device_properties = staticmethod(torch.cuda.get_device_properties) # type: ignore[assignment]
  198. get_raw_stream = staticmethod(get_cuda_stream) # type: ignore[assignment, arg-type]
  199. exchange_device = staticmethod(torch.cuda._exchange_device) # type: ignore[arg-type, has-type]
  200. maybe_exchange_device = staticmethod(torch.cuda._maybe_exchange_device) # type: ignore[arg-type, has-type]
  201. memory_allocated = staticmethod(torch.cuda.memory_allocated)
  202. is_bf16_supported = staticmethod(torch.cuda.is_bf16_supported) # type: ignore[arg-type]
  203. # Can be mock patched by @patch decorator.
  204. @staticmethod
  205. def is_available() -> bool:
  206. return torch.cuda.is_available()
  207. @staticmethod
  208. def get_compute_capability(device: torch.types.Device = None) -> Union[int, str]:
  209. if torch.version.hip is None:
  210. major, min = torch.cuda.get_device_capability(device)
  211. return major * 10 + min
  212. else:
  213. return torch.cuda.get_device_properties(device).gcnArchName.split(":", 1)[0]
  214. @staticmethod
  215. def is_triton_capable(device: torch.types.Device = None) -> bool:
  216. return (
  217. torch.version.hip is not None
  218. or torch.cuda.get_device_properties(device).major >= 7
  219. )
  220. @staticmethod
  221. def raise_if_triton_unavailable(device: torch.types.Device = None) -> None:
  222. from torch._inductor.exc import GPUTooOldForTriton
  223. if not CudaInterface.is_triton_capable(device):
  224. device_props = torch.cuda.get_device_properties(device)
  225. raise GPUTooOldForTriton(device_props, inspect.currentframe())
  226. import triton.backends
  227. if torch.version.hip is not None:
  228. if "amd" not in triton.backends.backends:
  229. raise RuntimeError("triton not built with the 'amd' backend")
  230. elif "nvidia" not in triton.backends.backends:
  231. raise RuntimeError("triton not built with the 'nvidia' backend")
  232. get_mtia_stream: Optional[Callable[[int], int]]
  233. if torch.mtia._is_compiled():
  234. from torch._C import _mtia_getCurrentRawStream as get_mtia_stream
  235. else:
  236. get_mtia_stream = None
  237. class MtiaInterface(DeviceInterface):
  238. device = torch.mtia.device # type: ignore[assignment]
  239. Event = torch.mtia.Event # type: ignore[assignment]
  240. Stream = torch.mtia.Stream # type: ignore[assignment]
  241. # pyrefly: ignore [bad-override]
  242. class Worker:
  243. @staticmethod
  244. def set_device(device: int) -> None:
  245. caching_worker_current_devices["mtia"] = device
  246. @staticmethod
  247. def current_device() -> int:
  248. if "mtia" in caching_worker_current_devices:
  249. return caching_worker_current_devices["mtia"]
  250. return torch.mtia.current_device()
  251. @staticmethod
  252. def get_device_properties(device: torch.types.Device = None) -> Any:
  253. if device is not None:
  254. if isinstance(device, str):
  255. device = torch.device(device)
  256. assert device.type == "mtia"
  257. if isinstance(device, torch.device):
  258. device = device.index
  259. if device is None:
  260. device = MtiaInterface.Worker.current_device()
  261. if "mtia" not in caching_worker_device_properties:
  262. device_prop = [
  263. torch.mtia.get_device_properties(i)
  264. for i in range(torch.mtia.device_count())
  265. ]
  266. caching_worker_device_properties["mtia"] = device_prop
  267. return caching_worker_device_properties["mtia"][device]
  268. current_device = staticmethod(torch.mtia.current_device)
  269. set_device = staticmethod(torch.mtia.set_device) # type: ignore[assignment]
  270. device_count = staticmethod(torch.mtia.device_count)
  271. stream = staticmethod(torch.mtia.stream) # type: ignore[assignment]
  272. # pyrefly: ignore [bad-override]
  273. current_stream = staticmethod(torch.mtia.current_stream)
  274. set_stream = staticmethod(torch.mtia.set_stream) # type: ignore[assignment]
  275. _set_stream_by_id = staticmethod(torch.mtia._set_stream_by_id) # type: ignore[assignment]
  276. synchronize = staticmethod(torch.mtia.synchronize)
  277. get_device_properties = staticmethod(torch.mtia.get_device_properties) # type: ignore[assignment]
  278. get_raw_stream = staticmethod(get_mtia_stream) # type: ignore[assignment, arg-type]
  279. exchange_device = staticmethod(torch.mtia._exchange_device) # type: ignore[arg-type, has-type]
  280. maybe_exchange_device = staticmethod(torch.mtia._maybe_exchange_device) # type: ignore[arg-type, has-type]
  281. memory_allocated = staticmethod(torch.mtia.memory_allocated) # type: ignore[assignment]
  282. is_bf16_supported = staticmethod(torch.mtia.is_bf16_supported) # type: ignore[arg-type]
  283. # Can be mock patched by @patch decorator.
  284. @staticmethod
  285. def is_available() -> bool:
  286. ret = torch.mtia.is_available()
  287. return ret
  288. @staticmethod
  289. def get_compute_capability(device: torch.types.Device = None) -> Any:
  290. cc = torch.mtia.get_device_capability(device)
  291. return cc
  292. @staticmethod
  293. def is_triton_capable(device: torch.types.Device = None) -> bool:
  294. return True
  295. @staticmethod
  296. def raise_if_triton_unavailable(device: torch.types.Device = None) -> None:
  297. import triton.backends
  298. if "mtia" not in triton.backends.backends:
  299. raise RuntimeError("triton not built with the 'mtia' backend")
  300. get_xpu_stream: Optional[Callable[[int], int]]
  301. if torch.xpu._is_compiled():
  302. from torch._C import _xpu_getCurrentRawStream as get_xpu_stream
  303. else:
  304. get_xpu_stream = None
  305. class XpuInterface(DeviceInterface):
  306. device = torch.xpu.device # type: ignore[assignment]
  307. Event = torch.xpu.Event # type: ignore[assignment]
  308. Stream = torch.xpu.Stream # type: ignore[assignment]
  309. # pyrefly: ignore [bad-override]
  310. class Worker:
  311. @staticmethod
  312. def set_device(device: int) -> None:
  313. caching_worker_current_devices["xpu"] = device
  314. @staticmethod
  315. def current_device() -> int:
  316. if "xpu" in caching_worker_current_devices:
  317. return caching_worker_current_devices["xpu"]
  318. return torch.xpu.current_device()
  319. @staticmethod
  320. def get_device_properties(device: torch.types.Device = None) -> Any:
  321. if device is not None:
  322. if isinstance(device, str):
  323. device = torch.device(device)
  324. assert device.type == "xpu"
  325. if isinstance(device, torch.device):
  326. device = device.index
  327. if device is None:
  328. device = XpuInterface.Worker.current_device()
  329. if "xpu" not in caching_worker_device_properties:
  330. device_prop = [
  331. torch.xpu.get_device_properties(i)
  332. for i in range(torch.xpu.device_count())
  333. ]
  334. caching_worker_device_properties["xpu"] = device_prop
  335. return caching_worker_device_properties["xpu"][device]
  336. current_device = staticmethod(torch.xpu.current_device)
  337. set_device = staticmethod(torch.xpu.set_device)
  338. device_count = staticmethod(torch.xpu.device_count) # type: ignore[has-type]
  339. stream = staticmethod(torch.xpu.stream) # type: ignore[assignment]
  340. # pyrefly: ignore [bad-override]
  341. current_stream = staticmethod(torch.xpu.current_stream)
  342. set_stream = staticmethod(torch.xpu.set_stream) # type: ignore[assignment]
  343. _set_stream_by_id = staticmethod(torch.xpu._set_stream_by_id) # type: ignore[assignment]
  344. synchronize = staticmethod(torch.xpu.synchronize)
  345. get_device_properties = staticmethod(torch.xpu.get_device_properties) # type: ignore[assignment]
  346. get_raw_stream = staticmethod(get_xpu_stream) # type: ignore[assignment, arg-type]
  347. exchange_device = staticmethod(torch.xpu._exchange_device) # type: ignore[arg-type, has-type]
  348. maybe_exchange_device = staticmethod(torch.xpu._maybe_exchange_device) # type: ignore[arg-type, has-type]
  349. memory_allocated = staticmethod(torch.xpu.memory_allocated)
  350. # Can be mock patched by @patch decorator.
  351. @staticmethod
  352. def is_available() -> bool:
  353. return torch.xpu.is_available()
  354. @staticmethod
  355. def get_compute_capability(device: torch.types.Device = None) -> Any:
  356. cc = torch.xpu.get_device_capability(device)
  357. return cc
  358. @staticmethod
  359. def is_bf16_supported(including_emulation: bool = False) -> bool:
  360. return torch.xpu.is_bf16_supported()
  361. @staticmethod
  362. def is_triton_capable(device: torch.types.Device = None) -> bool:
  363. return True
  364. @staticmethod
  365. def raise_if_triton_unavailable(device: torch.types.Device = None) -> None:
  366. import triton.backends
  367. if "intel" not in triton.backends.backends:
  368. raise RuntimeError("triton not built with the 'intel' backend")
  369. @dataclass
  370. class CpuDeviceProperties:
  371. multi_processor_count: int
  372. class CpuInterface(DeviceInterface):
  373. # pyrefly: ignore [bad-override]
  374. class Event(torch.Event):
  375. def __init__(self, enable_timing: bool = True) -> None:
  376. self.time = 0.0
  377. def elapsed_time(self, other: Any) -> float:
  378. return (other.time - self.time) * 1000
  379. def record(self, stream: Any = None) -> None:
  380. self.time = time.perf_counter()
  381. # pyrefly: ignore [bad-override]
  382. class Worker:
  383. @staticmethod
  384. def get_device_properties(
  385. device: torch.types.Device = None,
  386. ) -> CpuDeviceProperties:
  387. import multiprocessing
  388. cpu_count = multiprocessing.cpu_count()
  389. return CpuDeviceProperties(cpu_count)
  390. @staticmethod
  391. def is_available() -> bool:
  392. return True
  393. @staticmethod
  394. def is_bf16_supported(including_emulation: bool = False) -> bool:
  395. return True
  396. @staticmethod
  397. def get_compute_capability(device: torch.types.Device = None) -> str:
  398. return ""
  399. @staticmethod
  400. def get_raw_stream(device_idx: Any) -> int:
  401. return 0
  402. @staticmethod
  403. def current_device() -> int:
  404. return 0
  405. @staticmethod
  406. def synchronize(device: torch.types.Device = None) -> None:
  407. pass
  408. @staticmethod
  409. def is_triton_capable(device: torch.types.Device = None) -> bool:
  410. return True
  411. @staticmethod
  412. def raise_if_triton_unavailable(device: torch.types.Device = None) -> None:
  413. import triton.backends
  414. if "cpu" not in triton.backends.backends:
  415. raise RuntimeError("triton not built with the 'cpu' backend")
  416. class MpsInterface(DeviceInterface):
  417. @staticmethod
  418. def is_bf16_supported(including_emulation: bool = False) -> bool:
  419. return torch.backends.mps.is_macos_or_newer(14, 0)
  420. @classmethod
  421. def is_dtype_supported(
  422. cls, dtype: torch.dtype, including_emulation: bool = False
  423. ) -> bool:
  424. if dtype in [torch.float64, torch.complex128]:
  425. return False
  426. return dtype != torch.bfloat16 or cls.is_bf16_supported(including_emulation)
  427. @staticmethod
  428. def is_available() -> bool:
  429. return torch.backends.mps.is_available()
  430. @staticmethod
  431. def current_device() -> int:
  432. return 0
  433. @staticmethod
  434. def get_compute_capability(device: torch.types.Device = None) -> str:
  435. return ""
  436. @staticmethod
  437. def synchronize(device: torch.types.Device = None) -> None:
  438. torch.mps.synchronize()
  439. # pyrefly: ignore [bad-override]
  440. class Worker:
  441. @staticmethod
  442. def get_device_properties(device: torch.types.Device = None) -> Any:
  443. return namedtuple("MPSProperties", ["multi_processor_count"])(
  444. torch.backends.mps.get_core_count() # type: ignore[arg-type]
  445. )
  446. @staticmethod
  447. def current_device() -> int:
  448. return 0
  449. device_interfaces: dict[str, type[DeviceInterface]] = {}
  450. _device_initialized = False
  451. def register_interface_for_device(
  452. device: Union[str, torch.device], device_interface: type[DeviceInterface]
  453. ) -> None:
  454. if isinstance(device, torch.device):
  455. device = device.type
  456. device_interfaces[device] = device_interface
  457. def get_interface_for_device(device: Union[str, torch.device]) -> type[DeviceInterface]:
  458. if isinstance(device, torch.device):
  459. device = device.type
  460. if not _device_initialized:
  461. init_device_reg()
  462. if device in device_interfaces:
  463. return device_interfaces[device]
  464. raise NotImplementedError(f"No interface for device {device}")
  465. def get_registered_device_interfaces() -> Iterable[tuple[str, type[DeviceInterface]]]:
  466. if not _device_initialized:
  467. init_device_reg()
  468. return device_interfaces.items()
  469. def init_device_reg() -> None:
  470. global _device_initialized
  471. register_interface_for_device("cuda", CudaInterface)
  472. for i in range(torch.cuda.device_count()):
  473. register_interface_for_device(f"cuda:{i}", CudaInterface)
  474. register_interface_for_device("xpu", XpuInterface)
  475. for i in range(torch.xpu.device_count()):
  476. register_interface_for_device(f"xpu:{i}", XpuInterface)
  477. register_interface_for_device("mtia", MtiaInterface)
  478. for i in range(torch.mtia.device_count()):
  479. register_interface_for_device(f"mtia:{i}", MtiaInterface)
  480. register_interface_for_device("cpu", CpuInterface)
  481. register_interface_for_device("mps", MpsInterface)
  482. _device_initialized = True