__init__.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337
  1. # mypy: allow-untyped-defs
  2. # pylint: disable=useless-parent-delegation
  3. from __future__ import annotations
  4. from typing import cast, Generic, Optional, TYPE_CHECKING, TypeVar, Union
  5. import torch
  6. if TYPE_CHECKING:
  7. from collections.abc import Callable
# Public names exported by this module.
__all__ = ["Future", "collect_all", "wait_all"]
# Type of the value held by a ``Future`` (returned by ``wait()`` / ``value()``).
T = TypeVar("T")
# Return type of a callback passed to ``Future.then``.
S = TypeVar("S")
  11. class Future(torch._C.Future, Generic[T]):
  12. r"""
  13. Wrapper around a ``torch._C.Future`` which encapsulates an asynchronous
  14. execution of a callable, e.g. :meth:`~torch.distributed.rpc.rpc_async`. It
  15. also exposes a set of APIs to add callback functions and set results.
  16. .. warning:: GPU support is a beta feature, subject to changes.
  17. """
  18. def __init__(
  19. self, *, devices: Optional[list[Union[int, str, torch.device]]] = None
  20. ):
  21. r"""
  22. Create an empty unset ``Future``. If the future is intended to hold
  23. values containing CUDA tensors, (a superset of) their CUDA devices must
  24. be specified at construction. (This is only supported if
  25. ``torch.cuda.is_available()`` returns ``True``). This is needed to
  26. ensure proper CUDA stream synchronization. The child futures, returned
  27. by the ``then`` method, will inherit these devices.
  28. Args:
  29. devices(``List[Union[int, str, torch.device]]``, optional): the set
  30. of devices on which tensors contained in this future's value are
  31. allowed to reside and on which callbacks are allowed to operate.
  32. """
  33. if devices is None:
  34. devices = []
  35. super().__init__([torch.device(d) for d in devices])
  36. def done(self) -> bool:
  37. r"""
  38. Return ``True`` if this ``Future`` is done. A ``Future`` is done if it
  39. has a result or an exception.
  40. If the value contains tensors that reside on GPUs, ``Future.done()``
  41. will return ``True`` even if the asynchronous kernels that are
  42. populating those tensors haven't yet completed running on the device,
  43. because at such stage the result is already usable, provided one
  44. performs the appropriate synchronizations (see :meth:`wait`).
  45. """
  46. return super().done()
  47. def wait(self) -> T:
  48. r"""
  49. Block until the value of this ``Future`` is ready.
  50. If the value contains tensors that reside on GPUs, then an additional
  51. synchronization is performed with the kernels (executing on the device)
  52. which may be asynchronously populating those tensors. Such sync is
  53. non-blocking, which means that ``wait()`` will insert the necessary
  54. instructions in the current streams to ensure that further operations
  55. enqueued on those streams will be properly scheduled after the async
  56. kernels but, once that is done, ``wait()`` will return, even if those
  57. kernels are still running. No further synchronization is required when
  58. accessing and using the values, as long as one doesn't change streams.
  59. Returns:
  60. The value held by this ``Future``. If the function (callback or RPC)
  61. creating the value has thrown an error, this ``wait`` method will
  62. also throw an error.
  63. """
  64. return super().wait()
  65. def value(self) -> T:
  66. r"""
  67. Obtain the value of an already-completed future.
  68. This method should only be called after a call to :meth:`wait` has
  69. completed, or inside a callback function passed to :meth:`then`. In
  70. other cases this ``Future`` may not yet hold a value and calling
  71. ``value()`` could fail.
  72. If the value contains tensors that reside on GPUs, then this method will
  73. *not* perform any additional synchronization. This should be done
  74. beforehand, separately, through a call to :meth:`wait` (except within
  75. callbacks, for which it's already being taken care of by :meth:`then`).
  76. Returns:
  77. The value held by this ``Future``. If the function (callback or RPC)
  78. creating the value has thrown an error, this ``value()`` method will
  79. also throw an error.
  80. """
  81. return super().value()
  82. def then(self, callback: Callable[[Future[T]], S]) -> Future[S]:
  83. r"""
  84. Append the given callback function to this ``Future``, which will be run
  85. when the ``Future`` is completed. Multiple callbacks can be added to
  86. the same ``Future``, but the order in which they will be executed cannot
  87. be guaranteed (to enforce a certain order consider chaining:
  88. ``fut.then(cb1).then(cb2)``). The callback must take one argument, which
  89. is the reference to this ``Future``. The callback function can use the
  90. :meth:`value` method to get the value. Note that if this ``Future`` is
  91. already completed, the given callback will be run immediately inline.
  92. If the ``Future``'s value contains tensors that reside on GPUs, the
  93. callback might be invoked while the async kernels that are populating
  94. those tensors haven't yet finished executing on the device. However, the
  95. callback will be invoked with some dedicated streams set as current
  96. (fetched from a global pool) which will be synchronized with those
  97. kernels. Hence any operation performed by the callback on these tensors
  98. will be scheduled on the device after the kernels complete. In other
  99. words, as long as the callback doesn't switch streams, it can safely
  100. manipulate the result without any additional synchronization. This is
  101. similar to the non-blocking behavior of :meth:`wait`.
  102. Similarly, if the callback returns a value that contains tensors that
  103. reside on a GPU, it can do so even if the kernels that are producing
  104. these tensors are still running on the device, as long as the callback
  105. didn't change streams during its execution. If one wants to change
  106. streams, one must be careful to re-synchronize them with the original
  107. streams, that is, those that were current when the callback was invoked.
  108. Args:
  109. callback(``Callable``): a ``Callable`` that takes this ``Future`` as
  110. the only argument.
  111. Returns:
  112. A new ``Future`` object that holds the return value of the
  113. ``callback`` and will be marked as completed when the given
  114. ``callback`` finishes.
  115. .. note:: Note that if the callback function throws, either
  116. through the original future being completed with an exception and
  117. calling ``fut.wait()``, or through other code in the callback, the
  118. future returned by ``then`` will be marked appropriately with the
  119. encountered error. However, if this callback later completes
  120. additional futures, those futures are not marked as completed with
  121. an error and the user is responsible for handling completion/waiting
  122. on those futures independently.
  123. Example::
  124. >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES)
  125. >>> def callback(fut):
  126. ... print(f"RPC return value is {fut.wait()}.")
  127. >>> fut = torch.futures.Future()
  128. >>> # The inserted callback will print the return value when
  129. >>> # receiving the response from "worker1"
  130. >>> cb_fut = fut.then(callback)
  131. >>> chain_cb_fut = cb_fut.then(
  132. ... lambda x : print(f"Chained cb done. {x.wait()}")
  133. ... )
  134. >>> fut.set_result(5)
  135. RPC return value is 5.
  136. Chained cb done. None
  137. """
  138. return cast(Future[S], super().then(callback))
  139. def add_done_callback(self, callback: Callable[[Future[T]], None]) -> None:
  140. r"""
  141. Append the given callback function to this ``Future``, which will be run
  142. when the ``Future`` is completed. Multiple callbacks can be added to
  143. the same ``Future``, but the order in which they will be executed cannot
  144. be guaranteed. The callback must take one argument, which is the
  145. reference to this ``Future``. The callback function can use the
  146. :meth:`value` method to get the value. Note that if this ``Future`` is
  147. already completed, the given callback will be run inline.
  148. We recommend that you use the :meth:`then` method as it provides a way
  149. to synchronize after your callback has completed. ``add_done_callback``
  150. can be cheaper if your callback does not return anything. But both
  151. :meth:`then` and ``add_done_callback`` use the same callback
  152. registration API under the hood.
  153. With respect to GPU tensors, this method behaves in the same way as
  154. :meth:`then`.
  155. Args:
  156. callback(``Future``): a ``Callable`` that takes in one argument,
  157. which is the reference to this ``Future``.
  158. .. note:: Note that if the callback function throws, either
  159. through the original future being completed with an exception and
  160. calling ``fut.wait()``, or through other code in the callback,
  161. error handling must be carefully taken care of. For example, if
  162. this callback later completes additional futures, those futures are
  163. not marked as completed with an error and the user is responsible
  164. for handling completion/waiting on those futures independently.
  165. Example::
  166. >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES)
  167. >>> def callback(fut):
  168. ... print("This will run after the future has finished.")
  169. ... print(fut.wait())
  170. >>> fut = torch.futures.Future()
  171. >>> fut.add_done_callback(callback)
  172. >>> fut.set_result(5)
  173. This will run after the future has finished.
  174. 5
  175. """
  176. super().add_done_callback(callback)
  177. def set_result(self, result: T) -> None:
  178. r"""
  179. Set the result for this ``Future``, which will mark this ``Future`` as
  180. completed and trigger all attached callbacks. Note that a ``Future``
  181. cannot be marked completed twice.
  182. If the result contains tensors that reside on GPUs, this method can be
  183. called even if the asynchronous kernels that are populating those
  184. tensors haven't yet completed running on the device, provided that the
  185. streams on which those kernels were enqueued are set as the current ones
  186. when this method is called. Put simply, it's safe to call this method
  187. immediately after launching those kernels, without any additional
  188. synchronization, as long as one doesn't change streams in between. This
  189. method will record events on all the relevant current streams and will
  190. use them to ensure proper scheduling for all the consumers of this
  191. ``Future``.
  192. Args:
  193. result (object): the result object of this ``Future``.
  194. Example::
  195. >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES)
  196. >>> import threading
  197. >>> import time
  198. >>> def slow_set_future(fut, value):
  199. ... time.sleep(0.5)
  200. ... fut.set_result(value)
  201. >>> fut = torch.futures.Future()
  202. >>> t = threading.Thread(
  203. ... target=slow_set_future,
  204. ... args=(fut, torch.ones(2) * 3)
  205. ... )
  206. >>> t.start()
  207. >>> print(fut.wait())
  208. tensor([3., 3.])
  209. >>> t.join()
  210. """
  211. super().set_result(result)
  212. def set_exception(self, result: T) -> None:
  213. r"""
  214. Set an exception for this ``Future``, which will mark this ``Future`` as
  215. completed with an error and trigger all attached callbacks. Note that
  216. when calling wait()/value() on this ``Future``, the exception set here
  217. will be raised inline.
  218. Args:
  219. result (BaseException): the exception for this ``Future``.
  220. Example::
  221. >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES)
  222. >>> fut = torch.futures.Future()
  223. >>> fut.set_exception(ValueError("foo"))
  224. >>> fut.wait()
  225. Traceback (most recent call last):
  226. ...
  227. ValueError: foo
  228. """
  229. if not isinstance(result, Exception):
  230. raise AssertionError(
  231. f"{result} is of type {type(result)}, not an Exception."
  232. )
  233. def raise_error(fut_result):
  234. raise fut_result
  235. super()._set_unwrap_func(raise_error)
  236. self.set_result(result) # type: ignore[arg-type]
  237. def collect_all(futures: list[Future]) -> Future[list[Future]]:
  238. r"""
  239. Collects the provided :class:`~torch.futures.Future` objects into a single
  240. combined :class:`~torch.futures.Future` that is completed when all of the
  241. sub-futures are completed.
  242. Args:
  243. futures (list): a list of :class:`~torch.futures.Future` objects.
  244. Returns:
  245. Returns a :class:`~torch.futures.Future` object to a list of the passed
  246. in Futures.
  247. Example::
  248. >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES)
  249. >>> fut0 = torch.futures.Future()
  250. >>> fut1 = torch.futures.Future()
  251. >>> fut = torch.futures.collect_all([fut0, fut1])
  252. >>> fut0.set_result(0)
  253. >>> fut1.set_result(1)
  254. >>> fut_list = fut.wait()
  255. >>> print(f"fut0 result = {fut_list[0].wait()}")
  256. fut0 result = 0
  257. >>> print(f"fut1 result = {fut_list[1].wait()}")
  258. fut1 result = 1
  259. """
  260. return cast(
  261. Future[list[Future]],
  262. torch._C._collect_all(cast(list[torch._C.Future], futures)),
  263. )
  264. def wait_all(futures: list[Future]) -> list:
  265. r"""
  266. Waits for all provided futures to be complete, and returns
  267. the list of completed values. If any of the futures encounters an error,
  268. the method will exit early and report the error not waiting for other
  269. futures to complete.
  270. Args:
  271. futures (list): a list of :class:`~torch.futures.Future` object.
  272. Returns:
  273. A list of the completed :class:`~torch.futures.Future` results. This
  274. method will throw an error if ``wait`` on any
  275. :class:`~torch.futures.Future` throws.
  276. """
  277. return [
  278. fut.wait()
  279. for fut in torch._C._collect_all(cast(list[torch._C.Future], futures)).wait()
  280. ]