dlpack.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231
  1. from typing import Any
  2. import torch
  3. import enum
  4. from torch._C import _to_dlpack as to_dlpack
  5. from torch.types import Device as _Device
  6. __all__ = [
  7. "DLDeviceType",
  8. "from_dlpack",
  9. ]
  10. class DLDeviceType(enum.IntEnum):
  11. # Enums as in DLPack specification (aten/src/ATen/dlpack.h)
  12. kDLCPU = 1,
  13. kDLCUDA = 2,
  14. kDLCUDAHost = 3,
  15. kDLOpenCL = 4,
  16. kDLVulkan = 7,
  17. kDLMetal = 8,
  18. kDLVPI = 9,
  19. kDLROCM = 10,
  20. kDLROCMHost = 11,
  21. kDLExtDev = 12,
  22. kDLCUDAManaged = 13,
  23. kDLOneAPI = 14,
  24. kDLWebGPU = 15,
  25. kDLHexagon = 16,
  26. kDLMAIA = 17,
  27. torch._C._add_docstr(to_dlpack, r"""to_dlpack(tensor) -> PyCapsule
  28. Returns an opaque object (a "DLPack capsule") representing the tensor.
  29. .. note::
  30. ``to_dlpack`` is a legacy DLPack interface. The capsule it returns
  31. cannot be used for anything in Python other than use it as input to
  32. ``from_dlpack``. The more idiomatic use of DLPack is to call
  33. ``from_dlpack`` directly on the tensor object - this works when that
  34. object has a ``__dlpack__`` method, which PyTorch and most other
  35. libraries indeed have now.
  36. .. warning::
  37. Only call ``from_dlpack`` once per capsule produced with ``to_dlpack``.
  38. Behavior when a capsule is consumed multiple times is undefined.
  39. Args:
  40. tensor: a tensor to be exported
  41. The DLPack capsule shares the tensor's memory.
  42. """)
  43. # TODO: add a typing.Protocol to be able to tell Mypy that only objects with
  44. # __dlpack__ and __dlpack_device__ methods are accepted.
  45. def from_dlpack(
  46. ext_tensor: Any,
  47. *,
  48. device: _Device | None = None,
  49. copy: bool | None = None
  50. ) -> 'torch.Tensor':
  51. """from_dlpack(ext_tensor) -> Tensor
  52. Converts a tensor from an external library into a ``torch.Tensor``.
  53. The returned PyTorch tensor will share the memory with the input tensor
  54. (which may have come from another library). Note that in-place operations
  55. will therefore also affect the data of the input tensor. This may lead to
  56. unexpected issues (e.g., other libraries may have read-only flags or
  57. immutable data structures), so the user should only do this if they know
  58. for sure that this is fine.
  59. Args:
  60. ext_tensor (object with ``__dlpack__`` attribute, or a DLPack capsule):
  61. The tensor or DLPack capsule to convert.
  62. If ``ext_tensor`` is a tensor (or ndarray) object, it must support
  63. the ``__dlpack__`` protocol (i.e., have a ``ext_tensor.__dlpack__``
  64. method). Otherwise ``ext_tensor`` may be a DLPack capsule, which is
  65. an opaque ``PyCapsule`` instance, typically produced by a
  66. ``to_dlpack`` function or method.
  67. device (torch.device or str or None): An optional PyTorch device
  68. specifying where to place the new tensor. If None (default), the
  69. new tensor will be on the same device as ``ext_tensor``.
  70. copy (bool or None): An optional boolean indicating whether or not to copy
  71. ``self``. If None, PyTorch will copy only if necessary.
  72. Examples::
  73. >>> import torch.utils.dlpack
  74. >>> t = torch.arange(4)
  75. # Convert a tensor directly (supported in PyTorch >= 1.10)
  76. >>> t2 = torch.from_dlpack(t)
  77. >>> t2[:2] = -1 # show that memory is shared
  78. >>> t2
  79. tensor([-1, -1, 2, 3])
  80. >>> t
  81. tensor([-1, -1, 2, 3])
  82. # The old-style DLPack usage, with an intermediate capsule object
  83. >>> capsule = torch.utils.dlpack.to_dlpack(t)
  84. >>> capsule
  85. <capsule object "dltensor" at ...>
  86. >>> t3 = torch.from_dlpack(capsule)
  87. >>> t3
  88. tensor([-1, -1, 2, 3])
  89. >>> t3[0] = -9 # now we're sharing memory between 3 tensors
  90. >>> t3
  91. tensor([-9, -1, 2, 3])
  92. >>> t2
  93. tensor([-9, -1, 2, 3])
  94. >>> t
  95. tensor([-9, -1, 2, 3])
  96. """
  97. if hasattr(ext_tensor, '__dlpack__'):
  98. # Only populate kwargs if any of the optional arguments are, in fact, not None. Otherwise,
  99. # leave them out, since we might end up falling back to no-extra-kwargs __dlpack__ call.
  100. kwargs: dict[str, Any] = {}
  101. kwargs["max_version"] = (1, 0)
  102. # Track copy request for potential manual handling
  103. requested_copy = copy
  104. producer_handled_copy = True
  105. cross_device_transfer = False # Will be set to True if device transfer is needed
  106. if copy is not None:
  107. kwargs["copy"] = copy
  108. # Parse the device parameter.
  109. # At this moment, it can either be a torch.device or a str representing
  110. # a torch.device, e.g. "cpu", "cuda", etc.
  111. # Get source device first (we need it to detect cross-device transfers)
  112. ext_device = ext_tensor.__dlpack_device__()
  113. if device is not None:
  114. if isinstance(device, str):
  115. device = torch.device(device)
  116. if not isinstance(device, torch.device):
  117. raise AssertionError(f"from_dlpack: unsupported device type: {type(device)}")
  118. # Convert target device to DLPack format
  119. target_dl_device = torch._C._torchDeviceToDLDevice(device)
  120. # Detect cross-device transfer by comparing source and target devices
  121. # E.g. CPU->CUDA, cuda:0->cuda:1, etc.
  122. cross_device_transfer = (ext_device != target_dl_device)
  123. # Only pass dl_device to producer if NOT cross-device transfer
  124. if not cross_device_transfer:
  125. kwargs["dl_device"] = target_dl_device
  126. # Cross-device transfer always requires a copy
  127. if cross_device_transfer and copy is False:
  128. raise ValueError(
  129. f"cannot move DLPack tensor from device {ext_device} to {target_dl_device} "
  130. "without copying. Set copy=None or copy=True."
  131. )
  132. # ext_device is either CUDA or ROCm, we need to pass the current
  133. # stream
  134. if ext_device[0] in (DLDeviceType.kDLCUDA, DLDeviceType.kDLROCM):
  135. stream = torch.cuda.current_stream(f'cuda:{ext_device[1]}')
  136. # cuda_stream is the pointer to the stream and it is a public
  137. # attribute, but it is not documented
  138. # The array API specify that the default legacy stream must be passed
  139. # with a value of 1 for CUDA
  140. # https://data-apis.org/array-api/latest/API_specification/array_object.html?dlpack-self-stream-none#dlpack-self-stream-none
  141. is_cuda = ext_device[0] == DLDeviceType.kDLCUDA
  142. # Since pytorch is not using PTDS by default, lets directly pass
  143. # the legacy stream
  144. stream_ptr = 1 if is_cuda and stream.cuda_stream == 0 else stream.cuda_stream
  145. kwargs["stream"] = stream_ptr
  146. # Try different parameter combinations until one works
  147. dlpack = None
  148. # Attempt 1: Try with all the parameters
  149. try:
  150. dlpack = ext_tensor.__dlpack__(**kwargs)
  151. except TypeError:
  152. pass
  153. # Attempt 2: Remove max_version
  154. if dlpack is None:
  155. kwargs.pop("max_version", None)
  156. try:
  157. dlpack = ext_tensor.__dlpack__(**kwargs)
  158. except TypeError:
  159. pass
  160. # Attempt 3: Remove copy
  161. if dlpack is None:
  162. kwargs.pop("copy", None)
  163. producer_handled_copy = False
  164. try:
  165. dlpack = ext_tensor.__dlpack__(**kwargs)
  166. except TypeError:
  167. pass
  168. # Attempt 4: Remove dl_device
  169. if dlpack is None:
  170. kwargs.pop("dl_device", None)
  171. dlpack = ext_tensor.__dlpack__(**kwargs)
  172. tensor = torch._C._from_dlpack(dlpack)
  173. # Manual copy if producer didn't handle it (cross-device already copies via .to())
  174. if requested_copy is True and not producer_handled_copy and not cross_device_transfer:
  175. tensor = tensor.clone()
  176. # Handle cross-device transfer by moving tensor to target device
  177. if cross_device_transfer:
  178. tensor = tensor.to(device)
  179. return tensor
  180. else:
  181. if device is not None or copy is not None:
  182. raise AssertionError(
  183. "device and copy kwargs not supported when ext_tensor is already a DLPack capsule."
  184. )
  185. # Old versions just call the converter
  186. dlpack = ext_tensor
  187. return torch._C._from_dlpack(dlpack)