_device.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. # mypy: allow-untyped-defs
  2. import functools
  3. from typing import Optional
  4. import torch
  5. from torch._C import _len_torch_function_stack
  6. from torch.overrides import _pop_mode, _push_mode, TorchFunctionMode
  7. from torch.utils._contextlib import context_decorator
  8. CURRENT_DEVICE: torch.device | None = None
  9. @functools.lru_cache(1)
  10. def _device_constructors():
  11. return {
  12. # standard ones
  13. torch.empty,
  14. torch.empty_permuted,
  15. torch.empty_strided,
  16. torch.empty_quantized,
  17. torch.ones,
  18. torch.arange,
  19. torch.bartlett_window,
  20. torch.blackman_window,
  21. torch.eye,
  22. torch.fft.fftfreq,
  23. torch.fft.rfftfreq,
  24. torch.full,
  25. torch.hamming_window,
  26. torch.hann_window,
  27. torch.kaiser_window,
  28. torch.linspace,
  29. torch.logspace,
  30. torch.nested.nested_tensor,
  31. # This function doesn't actually take a device argument
  32. # torch.normal,
  33. torch.rand,
  34. torch.randn,
  35. torch.randint,
  36. torch.randperm,
  37. torch.range,
  38. torch.sparse_coo_tensor,
  39. torch.sparse_compressed_tensor,
  40. torch.sparse_csr_tensor,
  41. torch.sparse_csc_tensor,
  42. torch.sparse_bsr_tensor,
  43. torch.sparse_bsc_tensor,
  44. torch.tril_indices,
  45. torch.triu_indices,
  46. torch.zeros,
  47. torch.asarray,
  48. # weird ones
  49. torch.tensor,
  50. torch.as_tensor,
  51. torch.scalar_tensor,
  52. }
  53. # NB: This is directly called from C++ in torch/csrc/Device.cpp
  54. class DeviceContext(TorchFunctionMode):
  55. def __init__(self, device) -> None:
  56. # pyrefly: ignore [read-only]
  57. self.device = torch.device(device)
  58. self.prev_mode: Optional[DeviceContext] = None
  59. def __enter__(self):
  60. global CURRENT_DEVICE
  61. self.old_device = CURRENT_DEVICE
  62. CURRENT_DEVICE = self.device
  63. # We need to put the device at the bottom of the stack
  64. # If we set default device within a function mode context
  65. # exiting that context mode will pop the device function mode off
  66. # of the stack incorrectly
  67. cur_stack = [_pop_mode() for _ in range(_len_torch_function_stack())]
  68. _push_mode(self)
  69. for mode in reversed(cur_stack):
  70. if isinstance(mode, DeviceContext):
  71. self.prev_mode = mode
  72. else:
  73. _push_mode(mode)
  74. def __exit__(self, exc_type, exc_val, exc_tb):
  75. global CURRENT_DEVICE
  76. CURRENT_DEVICE = self.old_device
  77. cur_stack = []
  78. # Invariant: there should only be one DeviceContext on the stack at any time
  79. # (At the bottom), pop all modes until we hit the bottom, assert it's a DeviceContext
  80. # or else someone else has popped it!
  81. for _ in range(_len_torch_function_stack() - 1):
  82. mode = _pop_mode()
  83. if isinstance(mode, DeviceContext):
  84. raise AssertionError(
  85. "Found nested DeviceContext on the mode stack where none expected"
  86. )
  87. cur_stack.append(mode)
  88. if _len_torch_function_stack() > 0:
  89. mode = _pop_mode()
  90. if not isinstance(mode, DeviceContext):
  91. raise AssertionError(
  92. "Expected a DeviceContext at the bottom of the mode stack"
  93. )
  94. if self.prev_mode is not None:
  95. _push_mode(self.prev_mode)
  96. for mode in reversed(cur_stack):
  97. _push_mode(mode)
  98. def __torch_function__(self, func, types, args=(), kwargs=None):
  99. kwargs = kwargs or {}
  100. if func in _device_constructors() and kwargs.get("device") is None:
  101. kwargs["device"] = self.device
  102. return func(*args, **kwargs)
  103. # NB: This is directly called from C++ in torch/csrc/Device.cpp
  104. def device_decorator(device, func):
  105. return context_decorator(lambda: device, func)
  106. def set_device(device):
  107. """
  108. Set the default device inside of the wrapped function by decorating it with this function.
  109. If you would like to use this as a context manager, use device as a
  110. context manager directly, e.g., ``with torch.device(device)``.
  111. """
  112. return lambda func: device_decorator(torch.device(device), func)