# functions.py (7.1 KB)
# mypy: allow-untyped-defs
import functools
  3. def async_execution(fn):
  4. r"""
  5. A decorator for a function indicating that the return value of the function
  6. is guaranteed to be a :class:`~torch.futures.Future` object and this
  7. function can run asynchronously on the RPC callee. More specifically, the
  8. callee extracts the :class:`~torch.futures.Future` returned by the wrapped
  9. function and installs subsequent processing steps as a callback to that
  10. :class:`~torch.futures.Future`. The installed callback will read the value
  11. from the :class:`~torch.futures.Future` when completed and send the
  12. value back as the RPC response. That also means the returned
  13. :class:`~torch.futures.Future` only exists on the callee side and is never
  14. sent through RPC. This decorator is useful when the wrapped function's
  15. (``fn``) execution needs to pause and resume due to, e.g., containing
  16. :meth:`~torch.distributed.rpc.rpc_async` or waiting for other signals.
  17. .. note:: To enable asynchronous execution, applications must pass the
  18. function object returned by this decorator to RPC APIs. If RPC detected
  19. attributes installed by this decorator, it knows that this function
  20. returns a ``Future`` object and will handle that accordingly.
  21. However, this does not mean this decorator has to be outmost one when
  22. defining a function. For example, when combined with ``@staticmethod``
  23. or ``@classmethod``, ``@rpc.functions.async_execution`` needs to be the
  24. inner decorator to allow the target function be recognized as a static
  25. or class function. This target function can still execute asynchronously
  26. because, when accessed, the static or class method preserves attributes
  27. installed by ``@rpc.functions.async_execution``.
  28. Example::
  29. The returned :class:`~torch.futures.Future` object can come from
  30. :meth:`~torch.distributed.rpc.rpc_async`,
  31. :meth:`~torch.futures.Future.then`, or :class:`~torch.futures.Future`
  32. constructor. The example below shows directly using the
  33. :class:`~torch.futures.Future` returned by
  34. :meth:`~torch.futures.Future.then`.
  35. >>> from torch.distributed import rpc
  36. >>>
  37. >>> # omitting setup and shutdown RPC
  38. >>>
  39. >>> # On all workers
  40. >>> @rpc.functions.async_execution
  41. >>> def async_add_chained(to, x, y, z):
  42. >>> # This function runs on "worker1" and returns immediately when
  43. >>> # the callback is installed through the `then(cb)` API. In the
  44. >>> # mean time, the `rpc_async` to "worker2" can run concurrently.
  45. >>> # When the return value of that `rpc_async` arrives at
  46. >>> # "worker1", "worker1" will run the lambda function accordingly
  47. >>> # and set the value for the previously returned `Future`, which
  48. >>> # will then trigger RPC to send the result back to "worker0".
  49. >>> return rpc.rpc_async(to, torch.add, args=(x, y)).then(
  50. >>> lambda fut: fut.wait() + z
  51. >>> )
  52. >>>
  53. >>> # On worker0
  54. >>> # xdoctest: +SKIP
  55. >>> ret = rpc.rpc_sync(
  56. >>> "worker1",
  57. >>> async_add_chained,
  58. >>> args=("worker2", torch.ones(2), 1, 1)
  59. >>> )
  60. >>> print(ret) # prints tensor([3., 3.])
  61. When combined with TorchScript decorators, this decorator must be the
  62. outmost one.
  63. >>> from torch import Tensor
  64. >>> from torch.futures import Future
  65. >>> from torch.distributed import rpc
  66. >>>
  67. >>> # omitting setup and shutdown RPC
  68. >>>
  69. >>> # On all workers
  70. >>> @torch.jit.script
  71. >>> def script_add(x: Tensor, y: Tensor) -> Tensor:
  72. >>> return x + y
  73. >>>
  74. >>> @rpc.functions.async_execution
  75. >>> @torch.jit.script
  76. >>> def async_add(to: str, x: Tensor, y: Tensor) -> Future[Tensor]:
  77. >>> return rpc.rpc_async(to, script_add, (x, y))
  78. >>>
  79. >>> # On worker0
  80. >>> ret = rpc.rpc_sync(
  81. >>> "worker1",
  82. >>> async_add,
  83. >>> args=("worker2", torch.ones(2), 1)
  84. >>> )
  85. >>> print(ret) # prints tensor([2., 2.])
  86. When combined with static or class method, this decorator must be the
  87. inner one.
  88. >>> from torch.distributed import rpc
  89. >>>
  90. >>> # omitting setup and shutdown RPC
  91. >>>
  92. >>> # On all workers
  93. >>> class AsyncExecutionClass:
  94. >>>
  95. >>> @staticmethod
  96. >>> @rpc.functions.async_execution
  97. >>> def static_async_add(to, x, y, z):
  98. >>> return rpc.rpc_async(to, torch.add, args=(x, y)).then(
  99. >>> lambda fut: fut.wait() + z
  100. >>> )
  101. >>>
  102. >>> @classmethod
  103. >>> @rpc.functions.async_execution
  104. >>> def class_async_add(cls, to, x, y, z):
  105. >>> ret_fut = torch.futures.Future()
  106. >>> rpc.rpc_async(to, torch.add, args=(x, y)).then(
  107. >>> lambda fut: ret_fut.set_result(fut.wait() + z)
  108. >>> )
  109. >>> return ret_fut
  110. >>>
  111. >>> @rpc.functions.async_execution
  112. >>> def bound_async_add(self, to, x, y, z):
  113. >>> return rpc.rpc_async(to, torch.add, args=(x, y)).then(
  114. >>> lambda fut: fut.wait() + z
  115. >>> )
  116. >>>
  117. >>> # On worker0
  118. >>> ret = rpc.rpc_sync(
  119. >>> "worker1",
  120. >>> AsyncExecutionClass.static_async_add,
  121. >>> args=("worker2", torch.ones(2), 1, 2)
  122. >>> )
  123. >>> print(ret) # prints tensor([4., 4.])
  124. >>>
  125. >>> ret = rpc.rpc_sync(
  126. >>> "worker1",
  127. >>> AsyncExecutionClass.class_async_add,
  128. >>> args=("worker2", torch.ones(2), 1, 2)
  129. >>> )
  130. >>> print(ret) # prints tensor([4., 4.])
  131. This decorator also works with RRef helpers, i.e., .
  132. :meth:`torch.distributed.rpc.RRef.rpc_sync`,
  133. :meth:`torch.distributed.rpc.RRef.rpc_async`, and
  134. :meth:`torch.distributed.rpc.RRef.remote`.
  135. >>> from torch.distributed import rpc
  136. >>>
  137. >>> # reuse the AsyncExecutionClass class above
  138. >>> rref = rpc.remote("worker1", AsyncExecutionClass)
  139. >>> ret = rref.rpc_sync().static_async_add("worker2", torch.ones(2), 1, 2)
  140. >>> print(ret) # prints tensor([4., 4.])
  141. >>>
  142. >>> rref = rpc.remote("worker1", AsyncExecutionClass)
  143. >>> ret = rref.rpc_async().static_async_add("worker2", torch.ones(2), 1, 2).wait()
  144. >>> print(ret) # prints tensor([4., 4.])
  145. >>>
  146. >>> rref = rpc.remote("worker1", AsyncExecutionClass)
  147. >>> ret = rref.remote().static_async_add("worker2", torch.ones(2), 1, 2).to_here()
  148. >>> print(ret) # prints tensor([4., 4.])
  149. """
  150. @functools.wraps(fn)
  151. def wrapper(*args, **kwargs):
  152. return fn(*args, **kwargs)
  153. # Can't declare and use attributes of function objects (mypy#2087)
  154. wrapper._wrapped_async_rpc_function = fn # type: ignore[attr-defined]
  155. return wrapper