api.py 36 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965
  1. # mypy: allow-untyped-decorators
  2. # mypy: allow-untyped-defs
  3. import collections
  4. import contextlib
  5. import functools
  6. import inspect
  7. import logging
  8. import threading
  9. from typing import Any, Generic, TYPE_CHECKING, TypeVar
  10. import torch
  11. from torch._C._distributed_rpc import (
  12. _cleanup_python_rpc_handler,
  13. _delete_all_user_and_unforked_owner_rrefs,
  14. _destroy_rref_context,
  15. _get_current_rpc_agent,
  16. _invoke_remote_builtin,
  17. _invoke_remote_python_udf,
  18. _invoke_remote_torchscript,
  19. _invoke_rpc_builtin,
  20. _invoke_rpc_python_udf,
  21. _invoke_rpc_torchscript,
  22. _is_current_rpc_agent_set,
  23. _reset_current_rpc_agent,
  24. _set_and_start_rpc_agent,
  25. get_rpc_timeout,
  26. PyRRef,
  27. RemoteProfilerManager,
  28. WorkerInfo,
  29. )
  30. from torch.futures import Future
  31. from ._utils import _group_membership_management, _update_group_membership
  32. from .constants import DEFAULT_SHUTDOWN_TIMEOUT, UNSET_RPC_TIMEOUT
  33. from .internal import (
  34. _build_rpc_profiling_key,
  35. _internal_rpc_pickler,
  36. PythonUDF,
  37. RPCExecMode,
  38. )
# Public API of this module. "new_method" is not a function defined by name:
# it is the module-level variable left bound by the loop further down that
# installs user-facing methods on ``RRef`` (``new_method = method_factory(...)``).
__all__ = [
    "shutdown",
    "get_worker_info",
    "remote",
    "rpc_sync",
    "rpc_async",
    "RRef",
    "AllGatherStates",
    "method_factory",
    "new_method",
]

logger = logging.getLogger(__name__)

# NB: RRef leaks are ignored during shutdown. Without this, applications would
# have to make sure there are no references to any RRef in the application code
# and that Python GC has done its job of deleting those RRefs. That could result
# in a bad debugging experience, especially for large applications. Therefore,
# by default, we ignore RRef leaks during shutdown. This is usually fine, as
# shutdown means the application has finished training and no longer cares
# about its state.
#
# To enable RRef leak checking, set this _ignore_rref_leak to False. It is
# passed to ``_destroy_rref_context`` by ``_finalize_shutdown``.
_ignore_rref_leak = True

# Pickler used to serialize Python UDFs in ``remote`` and ``_invoke_rpc``.
# Temporarily overridden by the ``_use_rpc_pickler`` context manager.
_default_pickler = _internal_rpc_pickler
  62. @contextlib.contextmanager
  63. def _use_rpc_pickler(rpc_pickler):
  64. r"""
  65. rpc_pickler: (.internal._InternalRPCPickler) Overrides the default RPC pickler
  66. """
  67. global _default_pickler
  68. _default_pickler = rpc_pickler
  69. try:
  70. yield
  71. finally:
  72. _default_pickler = _internal_rpc_pickler
  73. def _require_initialized(func):
  74. @functools.wraps(func)
  75. def wrapper(*args, **kwargs):
  76. if not _is_current_rpc_agent_set():
  77. raise RuntimeError(
  78. "RPC has not been initialized. Call "
  79. "torch.distributed.rpc.init_rpc first."
  80. )
  81. return func(*args, **kwargs)
  82. return wrapper
  83. class AllGatherStates:
  84. def __init__(self):
  85. # Each `gathered_objects` is an empty dict at beginning.
  86. # The leader worker is elected as the first worker in a sorted worker
  87. # name list. Whenever there is a worker entering `_all_gather()`, it
  88. # runs `_gather_to_leader()` on the leader to add its own name and
  89. # data obj to this dict. The leader also adds itself's name to the dict
  90. # on calling `_all_gather()`.
  91. # Once `set(gathered_objects.keys()) == _ALL_WORKER_NAMES`, the leader
  92. # will broadcast the gathered dict to all follower workers and set their
  93. # `gathered_objects` field and the `proceed_signal` field.
  94. self.gathered_objects = {}
  95. # All workers wait on this signal until it receives all gathered
  96. # objects.
  97. self.proceed_signal = threading.Event()
# States used by `def _all_gather()`.
# `_ALL_WORKER_NAMES` is initialized on initializing RPC layer
# (it is reassigned by `_init_rpc_states` from the agent's worker infos).
_ALL_WORKER_NAMES: set[Any] = set()
# Guards the two dicts below; every access to them in this module happens
# inside `with _all_gather_dict_lock:`. NOTE(review): an RLock (reentrant) is
# used — presumably so nested acquisition on one thread cannot deadlock; confirm.
_all_gather_dict_lock = threading.RLock()
# Maps a worker group key (sorted worker names concatenated) to the number of
# `_all_gather` rounds already started for that group.
_all_gather_sequence_id: dict[str, int] = {}
# Maps a round's sequence id to its `AllGatherStates`; entries are created on
# first access (defaultdict) and popped at the end of `_all_gather`.
_all_gather_sequence_id_to_states: collections.defaultdict = collections.defaultdict(
    AllGatherStates
)
  106. def _init_rpc_states(agent):
  107. worker_infos = agent.get_worker_infos()
  108. global _ALL_WORKER_NAMES
  109. _ALL_WORKER_NAMES = {worker_info.name for worker_info in worker_infos}
  110. # NB: backend implementation might have already set the rpc_agent.
  111. if not _is_current_rpc_agent_set():
  112. _set_and_start_rpc_agent(agent)
  113. def _gather_to_leader(sequence_id, worker_name, obj, worker_names=None):
  114. with _all_gather_dict_lock:
  115. if not worker_names:
  116. worker_names = _ALL_WORKER_NAMES
  117. assert worker_name in worker_names, (
  118. f"{worker_name} is not expected by leader."
  119. )
  120. states = _all_gather_sequence_id_to_states[sequence_id]
  121. assert worker_name not in states.gathered_objects, (
  122. f"{worker_name} reported intent sequence id {sequence_id} twice. "
  123. )
  124. states.gathered_objects[worker_name] = obj
  125. if worker_names == set(states.gathered_objects.keys()):
  126. states.proceed_signal.set()
  127. def _broadcast_to_followers(sequence_id, objects_map):
  128. with _all_gather_dict_lock:
  129. states = _all_gather_sequence_id_to_states[sequence_id]
  130. assert not states.proceed_signal.is_set(), (
  131. f"Termination signal sequence id {sequence_id} got set twice."
  132. )
  133. states.gathered_objects = objects_map
  134. states.proceed_signal.set()
  135. _thread_local_var = threading.local()
  136. @contextlib.contextmanager
  137. def _wait_all():
  138. r"""
  139. A context manager that collects all futures returned by ``rpc_async`` and
  140. waits them on the context manager's exit; relieving the user of needing
  141. to explicitly call wait.
  142. Example::
  143. >>> # xdoctest: +SKIP("distributed")
  144. >>> # On worker 0:
  145. >>> import torch
  146. >>> import torch.distributed.rpc as rpc
  147. >>> rpc.init_rpc("worker0", rank=0, world_size=2)
  148. >>> with rpc._wait_all():
  149. >>> fut_1 = rpc.rpc_async(dst, torch.add, (torch.ones(2, 2), 1))
  150. >>> fut_2 = rpc.rpc_async(dst, torch.add, (torch.ones(2, 2), 1))
  151. >>> #fut_1 and fut_2 are waited on
  152. """
  153. _thread_local_var.future_list = []
  154. try:
  155. yield
  156. finally:
  157. try:
  158. torch.futures.wait_all(_thread_local_var.future_list)
  159. finally:
  160. del _thread_local_var.future_list
@_require_initialized
def _all_gather(obj, worker_names=None, timeout: float = UNSET_RPC_TIMEOUT):
    r"""
    This is similar to torch.distributed.all_gather(), but is using RPC. It
    picks the worker with the smallest name (alphabetic order) as the leader.
    Then all followers send their data ``obj`` to the leader. After the leader
    has received all, it will broadcast the results back to all followers. This
    function blocks until all workers have received the gathered results.

    Args:
        obj: the object this worker contributes to the gather.
        worker_names: names of the participating workers. If falsy,
            ``_ALL_WORKER_NAMES`` is used. The leader computes
            ``worker_names - {leader_name}``, so this is expected to be a set.
        timeout (float): ``UNSET_RPC_TIMEOUT`` uses ``get_rpc_timeout()`` for
            the RPCs and waits on the proceed signal without a limit;
            ``DEFAULT_SHUTDOWN_TIMEOUT`` is passed through as the RPC timeout
            and the signal wait is also unbounded; any other value is used for
            both the RPCs and the signal wait.

    Returns:
        dict: worker name -> object contributed by that worker.

    Raises:
        RuntimeError: on the leader, if waiting on any follower's broadcast
            future raised a ``RuntimeError``.
    """
    if not worker_names:
        # NOTE(review): `_ALL_WORKER_NAMES` starts as `set()` at module level,
        # so this `is not None` check always passes; if it was never populated,
        # the set is empty and `min()` below raises ValueError instead.
        assert _ALL_WORKER_NAMES is not None, (
            "`_ALL_WORKER_NAMES` is not initialized for `def _all_gather`."
        )
        worker_names = _ALL_WORKER_NAMES
    leader_name = min(worker_names)

    self_name = _get_current_rpc_agent().get_worker_info().name

    # Each worker group (keyed by its sorted, concatenated names) has its own
    # round counter, so repeated calls by the same group get distinct ids.
    # NOTE(review): plain concatenation can make different groups/rounds map
    # to the same id (e.g. {"ab", "c"} vs {"a", "bc"}, or a name ending in a
    # digit followed by the round number) — confirm worker names rule this out.
    with _all_gather_dict_lock:
        concat_names = "".join(sorted(worker_names))
        sequence_num = _all_gather_sequence_id.get(concat_names, 0)
        _all_gather_sequence_id[concat_names] = sequence_num + 1
        sequence_id = concat_names + str(sequence_num)

    is_leader = leader_name == self_name

    if timeout == UNSET_RPC_TIMEOUT:
        # Timeout is specified by agent for RPC calls
        rpc_timeout = get_rpc_timeout()
        # No timeout for signal
        signal_timeout = None
    elif timeout == DEFAULT_SHUTDOWN_TIMEOUT:
        # No timeout for RPC (presumably DEFAULT_SHUTDOWN_TIMEOUT is the value
        # the agent treats as "no timeout" — confirm in .constants)
        rpc_timeout = timeout
        # No timeout for signal
        signal_timeout = None
    else:
        # Signal and RPC timeout use the same timeout
        signal_timeout = rpc_timeout = timeout

    # Phase 1: Followers send their object to the leader
    if is_leader:
        _gather_to_leader(sequence_id, self_name, obj, worker_names)
    else:
        rpc_sync(
            leader_name,
            _gather_to_leader,
            args=(sequence_id, self_name, obj, worker_names),
            timeout=rpc_timeout,
        )

    with _all_gather_dict_lock:
        states = _all_gather_sequence_id_to_states[sequence_id]

    # Timeout is either set by function parameter or None (which is indefinite).
    # The wait happens outside the lock because `_gather_to_leader` /
    # `_broadcast_to_followers` need that lock to set this signal.
    # NOTE(review): the bool returned by `wait()` is ignored; on timeout this
    # proceeds with whatever `gathered_objects` holds (possibly incomplete).
    states.proceed_signal.wait(timeout=signal_timeout)

    # Phase 2: Leader broadcast gathered results to all followers
    # Leader's signal is the first to be unblocked, after receiving all
    # followers' data objects.
    if is_leader:
        worker_name_to_response_future_dict = {}
        for follower_name in worker_names - {leader_name}:
            fut = rpc_async(
                follower_name,
                _broadcast_to_followers,
                args=(sequence_id, states.gathered_objects),
                timeout=rpc_timeout,
            )
            worker_name_to_response_future_dict[follower_name] = fut

        # Wait for every follower before reporting, so all failures are listed.
        errors = []
        for follower_name, fut in worker_name_to_response_future_dict.items():
            try:
                fut.wait()
            except RuntimeError as ex:
                errors.append((follower_name, ex))

        # NOTE(review): any RuntimeError is reported as "timed out", and raising
        # here skips the cleanup below, leaving this round's state in
        # `_all_gather_sequence_id_to_states`.
        if errors:
            raise RuntimeError(
                f"Followers {[e[0] for e in errors]} timed out in _all_gather "
                f"after {rpc_timeout:.2f} seconds. The first exception is {errors[0][1]}"
            )

    # Clean up for the states using the sequence_id
    with _all_gather_dict_lock:
        states = _all_gather_sequence_id_to_states.pop(sequence_id)
    return states.gathered_objects
  238. @_require_initialized
  239. def _barrier(worker_names):
  240. r"""
  241. Synchronizes local and remote RPC processes.
  242. This will block until all local and remote RPC processes specified under worker_names
  243. reach this method to wait for all outstanding work to complete.
  244. Args:
  245. worker_names (List[str]): The set of workers to synchronize.
  246. """
  247. try:
  248. _all_gather(None, set(worker_names))
  249. except RuntimeError:
  250. logger.exception("Failed to complete barrier")
  251. @_require_initialized
  252. def _wait_all_workers(timeout=DEFAULT_SHUTDOWN_TIMEOUT):
  253. r"""
  254. Block until all local and remote RPC processes reach this method and wait
  255. for all outstanding work to complete. Every RPC process must call this
  256. method before exit to perform a graceful shutdown. This should be used to
  257. terminate the RPC framework, and there is no guarantee that the RPC
  258. framework will work after this method returns.
  259. """
  260. try:
  261. _all_gather(None, timeout=timeout)
  262. except RuntimeError as ex:
  263. logger.exception("Failed to respond to 'Shutdown Proceed' in time")
  264. raise ex
@_require_initialized
def shutdown(graceful=True, timeout=DEFAULT_SHUTDOWN_TIMEOUT):
    r"""
    Perform a shutdown of the RPC agent, and then destroy the RPC agent. This
    stops the local agent from accepting outstanding requests, and shuts
    down the RPC framework by terminating all RPC threads. If ``graceful=True``,
    this will block until all local and remote RPC processes reach this method
    and wait for all outstanding work to complete. Otherwise, if
    ``graceful=False``, this is a local shutdown, and it does not wait for other
    RPC processes to reach this method.

    .. warning::
        For :class:`~torch.futures.Future` objects returned by
        :meth:`~torch.distributed.rpc.rpc_async`, ``future.wait()`` should not
        be called after ``shutdown()``.

    Args:
        graceful (bool): Whether to do a graceful shutdown or not. If True,
                         this will 1) wait until there are no pending system
                         messages for ``UserRRefs`` and delete them; 2) block
                         until all local and remote RPC processes have reached
                         this method and wait for all outstanding work to
                         complete.
        timeout (float): Only used when ``graceful=True``. Passed to
                         ``agent.join(shutdown=True, ...)`` and, for static
                         groups or non-TensorPipe agents, also to the wait for
                         all workers to reach shutdown.

    Example::
        Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly
        on both workers. Refer to :meth:`~torch.distributed.init_process_group`
        API for more details. For example,

        export MASTER_ADDR=localhost
        export MASTER_PORT=5678

        Then run the following code in two different processes:

        >>> # xdoctest: +SKIP
        >>> # On worker 0:
        >>> import torch
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> # do some work
        >>> result = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(1), 1))
        >>> # ready to shutdown
        >>> rpc.shutdown()

        >>> # On worker 1:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
        >>> # wait for worker 0 to finish work, and then shutdown.
        >>> rpc.shutdown()
    """
    if graceful:
        try:
            agent = _get_current_rpc_agent()
            from torch._C._distributed_rpc import TensorPipeAgent

            # Static groups (and non-TensorPipe agents): all workers meet in an
            # all-gather barrier, then user/unforked-owner RRefs are deleted,
            # then the agent is joined.
            if not isinstance(agent, TensorPipeAgent) or agent.is_static_group:
                _wait_all_workers(timeout)
                _delete_all_user_and_unforked_owner_rrefs()
                agent.join(shutdown=True, timeout=timeout)
            else:
                # This is a dynamic group so we need to grab the token for the operation
                my_worker_info = agent.get_worker_info()
                my_name = my_worker_info.name
                with _group_membership_management(agent.store, my_name, False):
                    all_worker_infos = agent.get_worker_infos()
                    # Notify every other worker via `_update_group_membership`.
                    # NOTE(review): these rpc_sync calls use the default RPC
                    # timeout, not `timeout`; the meaning of the trailing
                    # `False` argument is defined in `._utils` — confirm there.
                    for worker in all_worker_infos:
                        if worker.name != my_name:
                            rpc_sync(
                                worker.name,
                                _update_group_membership,
                                args=(my_worker_info, [], {}, False),
                            )
                    agent.join(shutdown=True, timeout=timeout)
        finally:
            # In case of errors, continue to complete the local shutdown.
            _finalize_shutdown()
    else:
        _finalize_shutdown()
  335. def _finalize_shutdown():
  336. try:
  337. # This raises a `TORCH_CHECK()` exception on RRef leak detected.
  338. _destroy_rref_context(_ignore_rref_leak)
  339. finally:
  340. _get_current_rpc_agent().shutdown()
  341. # clean up python rpc handler in shutdown(), see comments in
  342. # PythonRpcHandler::cleanup(), call it in python API because the
  343. # cleanup() function has python dependency, it assumes python
  344. # interpreter exists.
  345. # No matter if RRef leak exception is raised, this clean-up code
  346. # must run to avoid destruction segfault in Python 3.5.
  347. #
  348. # future.wait() should not be called after shutdown().
  349. # pythonRpcHandler is cleaned up in shutdown(), after
  350. # shutdown(), python objects returned from rpc python call can not be
  351. # resolved.
  352. _cleanup_python_rpc_handler()
  353. _reset_current_rpc_agent()
  354. @_require_initialized
  355. def get_worker_info(worker_name=None):
  356. r"""
  357. Get :class:`~torch.distributed.rpc.WorkerInfo` of a given worker name.
  358. Use this :class:`~torch.distributed.rpc.WorkerInfo` to avoid passing an
  359. expensive string on every invocation.
  360. Args:
  361. worker_name (str): the string name of a worker. If ``None``, return the
  362. the id of the current worker. (default ``None``)
  363. Returns:
  364. :class:`~torch.distributed.rpc.WorkerInfo` instance for the given
  365. ``worker_name`` or :class:`~torch.distributed.rpc.WorkerInfo` of the
  366. current worker if ``worker_name`` is ``None``.
  367. """
  368. if worker_name is not None:
  369. return _get_current_rpc_agent().get_worker_info(worker_name)
  370. else:
  371. return _get_current_rpc_agent().get_worker_info()
  372. def _to_worker_info(to):
  373. if isinstance(to, WorkerInfo):
  374. return to
  375. elif isinstance(to, (str, int)):
  376. return get_worker_info(to)
  377. else:
  378. raise ValueError(f"Cannot get WorkerInfo from name {to}")
  379. def _rref_typeof_on_owner(rref, blocking: bool = True):
  380. rref_type = type(rref.local_value())
  381. if blocking:
  382. return rref_type
  383. else:
  384. # Wrap result into a completed Future. This is so that if blocking=`False`
  385. # is specified, we return a future regardless of if this call is on user
  386. # or owner.
  387. future = Future[type]()
  388. future.set_result(rref_type)
  389. return future
  390. def _rref_typeof_on_user(
  391. rref, timeout: float = UNSET_RPC_TIMEOUT, blocking: bool = True
  392. ):
  393. fut = rpc_async(rref.owner(), _rref_typeof_on_owner, args=(rref,), timeout=timeout)
  394. if blocking:
  395. return fut.wait()
  396. else:
  397. return fut
# Type variable for the value type held by an RRef (``RRef[T]``).
T = TypeVar("T")

# pyrefly: ignore [invalid-annotation]
GenericWithOneTypeVar = Generic[T]


if TYPE_CHECKING:
    # Static type checkers see a plain generic subclass of the implementation.

    class RRef(PyRRef[T], Generic[T]):
        pass

else:
    try:
        # Combine the implementation class and the type class.
        class RRef(PyRRef, Generic[T]):
            pass

    except TypeError:
        # TypeError: metaclass conflict: the metaclass of a derived class
        # must be a (non-strict) subclass of the metaclasses of all its bases
        # Fallback: build a metaclass deriving from both bases' metaclasses.
        # Mypy doesn't understand __class__ (mypy bug #4177)
        class RRefMeta(PyRRef.__class__, GenericWithOneTypeVar.__class__):  # type: ignore[name-defined, misc, valid-type]
            pass

        # Combine the implementation class and the type class.
        # Types for classes expecting a certain generic parameter (mypy bug #7791)
        class RRef(PyRRef, GenericWithOneTypeVar, metaclass=RRefMeta):  # type: ignore[misc, no-redef, valid-type]
            pass
  419. # Install docstrings from `PyRRef` to `RRef`.
  420. #
  421. # This is for the fact that pybind11 generates the parameter
  422. # `self` as type `rpc.PyRRef`, so a `:inherited-members:`
  423. # under `.. autoclass:: RRef` does not work.
  424. # we have to do the following process to replace `rpc.PyRRef` with `rpc.RRef`.
  425. #
  426. def method_factory(method_name, docstring):
  427. def method(self, *args, **kwargs):
  428. return getattr(super(RRef, self), method_name)(*args, **kwargs)
  429. if method.__doc__:
  430. method.__doc__ = docstring
  431. return method
  432. for method_name, method in inspect.getmembers(PyRRef):
  433. # Ignore magic methods, except "__str__".
  434. if method_name.startswith("_") and method_name != "__str__":
  435. continue
  436. # Get pybind11 generated docstring.
  437. # It's like,
  438. """
  439. to_here(self: torch.distributed.rpc.PyRRef, timeout: float=-1.0) -> object
  440. Blocking call that copies the value of the RRef from the owner
  441. to the local node and returns it. If the current node is the
  442. owner, returns a reference to the local value.
  443. """
  444. docstring = getattr(method, "__doc__", None)
  445. assert docstring is not None, "RRef user-facing methods should all have docstrings."
  446. # Do surgery on pybind11 generated docstrings.
  447. docstring = docstring.replace(
  448. "torch.distributed.rpc.PyRRef", "torch.distributed.rpc.RRef"
  449. )
  450. # Attach user-facing RRef method with modified docstring.
  451. new_method = method_factory(method_name, docstring)
  452. setattr(RRef, method_name, new_method)
  453. @_require_initialized
  454. def remote(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT):
  455. r"""
  456. Make a remote call to run ``func`` on worker ``to`` and return an
  457. :class:`~torch.distributed.rpc.RRef` to the result value immediately.
  458. Worker ``to`` will be the owner of the returned
  459. :class:`~torch.distributed.rpc.RRef`, and the worker calling ``remote`` is
  460. a user. The owner manages the global reference count of its
  461. :class:`~torch.distributed.rpc.RRef`, and the owner
  462. :class:`~torch.distributed.rpc.RRef` is only destructed when globally there
  463. are no living references to it.
  464. Args:
  465. to (str or WorkerInfo or int): name/rank/``WorkerInfo`` of the destination worker.
  466. func (Callable): a callable function, such as Python callables, builtin
  467. operators (e.g. :meth:`~torch.add`) and annotated
  468. TorchScript functions.
  469. args (tuple): the argument tuple for the ``func`` invocation.
  470. kwargs (dict): is a dictionary of keyword arguments for the ``func``
  471. invocation.
  472. timeout (float, optional): timeout in seconds for this remote call. If the
  473. creation of this
  474. :class:`~torch.distributed.rpc.RRef` on worker
  475. ``to`` is not successfully processed on this
  476. worker within this timeout, then the next time
  477. there is an attempt to use the RRef (such as
  478. ``to_here()``), a timeout will be raised
  479. indicating this failure. A value of 0 indicates
  480. an infinite timeout, i.e. a timeout error will
  481. never be raised. If not provided, the default
  482. value set during initialization or with
  483. ``_set_rpc_timeout`` is used.
  484. Returns:
  485. A user :class:`~torch.distributed.rpc.RRef` instance to the result
  486. value. Use the blocking API :meth:`torch.distributed.rpc.RRef.to_here`
  487. to retrieve the result value locally.
  488. .. warning ::
  489. The ``remote`` API does not copy storages of argument tensors until
  490. sending them over the wire, which could be done by a different thread
  491. depending on the RPC backend type. The caller should make sure that the
  492. contents of those tensors stay intact until the returned RRef is
  493. confirmed by the owner, which can be checked using the
  494. :meth:`torch.distributed.rpc.RRef.confirmed_by_owner` API.
  495. .. warning ::
  496. Errors such as timeouts for the ``remote`` API are handled on a
  497. best-effort basis. This means that when remote calls initiated by
  498. ``remote`` fail, such as with a timeout error, we take a best-effort
  499. approach to error handling. This means that errors are handled and set
  500. on the resulting RRef on an asynchronous basis. If the RRef has not been
  501. used by the application before this handling (such as ``to_here`` or
  502. fork call), then future uses of the ``RRef`` will appropriately raise
  503. errors. However, it is possible that the user application will use the
  504. ``RRef`` before the errors are handled. In this case, errors may not be
  505. raised as they have not yet been handled.
  506. Example::
  507. Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly
  508. on both workers. Refer to :meth:`~torch.distributed.init_process_group`
  509. API for more details. For example,
  510. export MASTER_ADDR=localhost
  511. export MASTER_PORT=5678
  512. Then run the following code in two different processes:
  513. >>> # xdoctest: +SKIP
  514. >>> # On worker 0:
  515. >>> import torch
  516. >>> import torch.distributed.rpc as rpc
  517. >>> rpc.init_rpc("worker0", rank=0, world_size=2)
  518. >>> rref1 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3))
  519. >>> rref2 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1))
  520. >>> x = rref1.to_here() + rref2.to_here()
  521. >>> rpc.shutdown()
  522. >>> # On worker 1:
  523. >>> import torch.distributed.rpc as rpc
  524. >>> rpc.init_rpc("worker1", rank=1, world_size=2)
  525. >>> rpc.shutdown()
  526. Below is an example of running a TorchScript function using RPC.
  527. >>> # On both workers:
  528. >>> @torch.jit.script
  529. >>> def my_script_add(tensor: torch.Tensor, scalar: int):
  530. >>> return torch.add(tensor, scalar)
  531. >>> # On worker 0:
  532. >>> import torch.distributed.rpc as rpc
  533. >>> rpc.init_rpc("worker0", rank=0, world_size=2)
  534. >>> rref = rpc.remote("worker1", my_script_add, args=(torch.ones(2), 3))
  535. >>> rref.to_here()
  536. >>> rpc.shutdown()
  537. >>> # On worker 1:
  538. >>> import torch.distributed.rpc as rpc
  539. >>> rpc.init_rpc("worker1", rank=1, world_size=2)
  540. >>> rpc.shutdown()
  541. """
  542. torch._C._log_api_usage_once("torch.distributed.rpc_remote")
  543. qualified_name = torch.jit._builtins._find_builtin(func)
  544. dst_worker_info = _to_worker_info(to)
  545. should_profile = _get_should_profile()
  546. ctx_manager = _enable_rpc_profiler(
  547. should_profile, qualified_name, func, RPCExecMode.REMOTE, dst_worker_info
  548. )
  549. with ctx_manager as rf:
  550. args = args if args else ()
  551. kwargs = kwargs if kwargs else {}
  552. is_async_exec = hasattr(func, "_wrapped_async_rpc_function")
  553. if is_async_exec:
  554. wrapped = func._wrapped_async_rpc_function
  555. if isinstance(wrapped, torch.jit.ScriptFunction):
  556. func = wrapped
  557. if qualified_name is not None:
  558. rref = _invoke_remote_builtin(
  559. dst_worker_info, qualified_name, timeout, *args, **kwargs
  560. )
  561. elif isinstance(func, torch.jit.ScriptFunction):
  562. rref = _invoke_remote_torchscript(
  563. dst_worker_info.name,
  564. torch._jit_internal._qualified_name(func),
  565. timeout,
  566. is_async_exec,
  567. *args,
  568. **kwargs,
  569. )
  570. else:
  571. (pickled_python_udf, tensors) = _default_pickler.serialize(
  572. PythonUDF(func, args, kwargs)
  573. )
  574. rref = _invoke_remote_python_udf(
  575. dst_worker_info, pickled_python_udf, tensors, timeout, is_async_exec
  576. )
  577. # attach profiling information
  578. if should_profile:
  579. assert torch.autograd._profiler_enabled()
  580. assert rf is not None
  581. fut = rf._call_end_callbacks_on_future(rref._get_future())
  582. rref._set_profiling_future(fut)
  583. return rref
  584. def _invoke_rpc(
  585. to, func, rpc_type, args=None, kwargs=None, rpc_timeout: float = UNSET_RPC_TIMEOUT
  586. ):
  587. if not callable(func):
  588. raise TypeError("function should be callable.")
  589. qualified_name = torch.jit._builtins._find_builtin(func)
  590. dst_worker_info = _to_worker_info(to)
  591. should_profile = _get_should_profile()
  592. ctx_manager = _enable_rpc_profiler(
  593. should_profile, qualified_name, func, rpc_type, dst_worker_info
  594. )
  595. with ctx_manager as rf:
  596. args = args if args else ()
  597. kwargs = kwargs if kwargs else {}
  598. is_async_exec = hasattr(func, "_wrapped_async_rpc_function")
  599. if is_async_exec:
  600. # pyrefly: ignore [missing-attribute]
  601. wrapped = func._wrapped_async_rpc_function
  602. if isinstance(wrapped, torch.jit.ScriptFunction):
  603. func = wrapped
  604. if qualified_name is not None:
  605. fut = _invoke_rpc_builtin(
  606. dst_worker_info, qualified_name, rpc_timeout, *args, **kwargs
  607. )
  608. elif isinstance(func, torch.jit.ScriptFunction):
  609. fut = _invoke_rpc_torchscript(
  610. dst_worker_info.name,
  611. torch._jit_internal._qualified_name(func),
  612. args,
  613. kwargs,
  614. rpc_timeout,
  615. is_async_exec,
  616. )
  617. else:
  618. (pickled_python_udf, tensors) = _default_pickler.serialize(
  619. PythonUDF(func, args, kwargs)
  620. )
  621. fut = _invoke_rpc_python_udf(
  622. dst_worker_info, pickled_python_udf, tensors, rpc_timeout, is_async_exec
  623. )
  624. if should_profile:
  625. assert torch.autograd._profiler_enabled()
  626. assert rf is not None
  627. # Schedule profiling callbacks to run when the future completes.
  628. # This returns a future that is completed when the original future
  629. # completes and the profiling callbacks have been completed as well,
  630. # to guarantee that fut.wait() completes the profiling. This new
  631. # future will contain the same value as the original future.
  632. fut = rf._call_end_callbacks_on_future(fut)
  633. return fut
  634. @_require_initialized
  635. def rpc_sync(to, func, args=None, kwargs=None, timeout: float = UNSET_RPC_TIMEOUT):
  636. r"""
  637. Make a blocking RPC call to run function ``func`` on worker ``to``. RPC
  638. messages are sent and received in parallel to execution of Python code. This
  639. method is thread-safe.
  640. Args:
  641. to (str or WorkerInfo or int): name/rank/``WorkerInfo`` of the destination worker.
  642. func (Callable): a callable function, such as Python callables, builtin
  643. operators (e.g. :meth:`~torch.add`) and annotated
  644. TorchScript functions.
  645. args (tuple): the argument tuple for the ``func`` invocation.
  646. kwargs (dict): is a dictionary of keyword arguments for the ``func``
  647. invocation.
  648. timeout (float, optional): timeout in seconds to use for this RPC. If
  649. the RPC does not complete in this amount of
  650. time, an exception indicating it has
  651. timed out will be raised. A value of 0
  652. indicates an infinite timeout, i.e. a timeout
  653. error will never be raised. If not provided,
  654. the default value set during initialization
  655. or with ``_set_rpc_timeout`` is used.
  656. Returns:
  657. Returns the result of running ``func`` with ``args`` and ``kwargs``.
  658. Example::
  659. Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly
  660. on both workers. Refer to :meth:`~torch.distributed.init_process_group`
  661. API for more details. For example,
  662. export MASTER_ADDR=localhost
  663. export MASTER_PORT=5678
  664. Then run the following code in two different processes:
  665. >>> # xdoctest: +SKIP
  666. >>> # On worker 0:
  667. >>> import torch
  668. >>> import torch.distributed.rpc as rpc
  669. >>> rpc.init_rpc("worker0", rank=0, world_size=2)
  670. >>> ret = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(2), 3))
  671. >>> rpc.shutdown()
  672. >>> # On worker 1:
  673. >>> import torch.distributed.rpc as rpc
  674. >>> rpc.init_rpc("worker1", rank=1, world_size=2)
  675. >>> rpc.shutdown()
  676. Below is an example of running a TorchScript function using RPC.
  677. >>> # On both workers:
  678. >>> @torch.jit.script
  679. >>> def my_script_add(tensor: torch.Tensor, scalar: int):
  680. >>> return torch.add(tensor, scalar)
  681. >>> # On worker 0:
  682. >>> import torch.distributed.rpc as rpc
  683. >>> rpc.init_rpc("worker0", rank=0, world_size=2)
  684. >>> ret = rpc.rpc_sync("worker1", my_script_add, args=(torch.ones(2), 3))
  685. >>> rpc.shutdown()
  686. >>> # On worker 1:
  687. >>> import torch.distributed.rpc as rpc
  688. >>> rpc.init_rpc("worker1", rank=1, world_size=2)
  689. >>> rpc.shutdown()
  690. """
  691. torch._C._log_api_usage_once("torch.distributed.rpc_sync")
  692. fut = _invoke_rpc(to, func, RPCExecMode.SYNC, args, kwargs, timeout)
  693. return fut.wait()
  694. @_require_initialized
  695. def rpc_async(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT):
  696. r"""
  697. Make a non-blocking RPC call to run function ``func`` on worker ``to``. RPC
  698. messages are sent and received in parallel to execution of Python code. This
  699. method is thread-safe. This method will immediately return a
  700. :class:`~torch.futures.Future` that can be awaited on.
  701. Args:
  702. to (str or WorkerInfo or int): name/rank/``WorkerInfo`` of the destination worker.
  703. func (Callable): a callable function, such as Python callables, builtin
  704. operators (e.g. :meth:`~torch.add`) and annotated
  705. TorchScript functions.
  706. args (tuple): the argument tuple for the ``func`` invocation.
  707. kwargs (dict): is a dictionary of keyword arguments for the ``func``
  708. invocation.
  709. timeout (float, optional): timeout in seconds to use for this RPC. If
  710. the RPC does not complete in this amount of
  711. time, an exception indicating it has
  712. timed out will be raised. A value of 0
  713. indicates an infinite timeout, i.e. a timeout
  714. error will never be raised. If not provided,
  715. the default value set during initialization
  716. or with ``_set_rpc_timeout`` is used.
  717. Returns:
  718. Returns a :class:`~torch.futures.Future` object that can be waited
  719. on. When completed, the return value of ``func`` on ``args`` and
  720. ``kwargs`` can be retrieved from the :class:`~torch.futures.Future`
  721. object.
  722. .. warning ::
  723. Using GPU tensors as arguments or return values of ``func`` is not
  724. supported since we don't support sending GPU tensors over the wire. You
  725. need to explicitly copy GPU tensors to CPU before using them as
  726. arguments or return values of ``func``.
  727. .. warning ::
  728. The ``rpc_async`` API does not copy storages of argument tensors until
  729. sending them over the wire, which could be done by a different thread
  730. depending on the RPC backend type. The caller should make sure that the
  731. contents of those tensors stay intact until the returned
  732. :class:`~torch.futures.Future` completes.
  733. Example::
  734. Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly
  735. on both workers. Refer to :meth:`~torch.distributed.init_process_group`
  736. API for more details. For example,
  737. export MASTER_ADDR=localhost
  738. export MASTER_PORT=5678
  739. Then run the following code in two different processes:
  740. >>> # xdoctest: +SKIP
  741. >>> # On worker 0:
  742. >>> import torch
  743. >>> import torch.distributed.rpc as rpc
  744. >>> rpc.init_rpc("worker0", rank=0, world_size=2)
  745. >>> fut1 = rpc.rpc_async("worker1", torch.add, args=(torch.ones(2), 3))
  746. >>> fut2 = rpc.rpc_async("worker1", min, args=(1, 2))
  747. >>> result = fut1.wait() + fut2.wait()
  748. >>> rpc.shutdown()
  749. >>> # On worker 1:
  750. >>> import torch.distributed.rpc as rpc
  751. >>> rpc.init_rpc("worker1", rank=1, world_size=2)
  752. >>> rpc.shutdown()
  753. Below is an example of running a TorchScript function using RPC.
  754. >>> # On both workers:
  755. >>> @torch.jit.script
  756. >>> def my_script_add(tensor: torch.Tensor, scalar: int):
  757. >>> return torch.add(tensor, scalar)
  758. >>> # On worker 0:
  759. >>> import torch.distributed.rpc as rpc
  760. >>> rpc.init_rpc("worker0", rank=0, world_size=2)
  761. >>> fut = rpc.rpc_async("worker1", my_script_add, args=(torch.ones(2), 3))
  762. >>> ret = fut.wait()
  763. >>> rpc.shutdown()
  764. >>> # On worker 1:
  765. >>> import torch.distributed.rpc as rpc
  766. >>> rpc.init_rpc("worker1", rank=1, world_size=2)
  767. >>> rpc.shutdown()
  768. """
  769. torch._C._log_api_usage_once("torch.distributed.rpc_async")
  770. fut = _invoke_rpc(to, func, RPCExecMode.ASYNC, args, kwargs, timeout)
  771. if hasattr(_thread_local_var, "future_list"):
  772. _thread_local_var.future_list.append(fut)
  773. return fut
  774. def _get_should_profile():
  775. # Legacy profiler should be enabled. RPC profiling is not supported with
  776. # Kineto profiler.
  777. ActiveProfilerType = torch._C._profiler.ActiveProfilerType
  778. return (
  779. torch.autograd._profiler_enabled()
  780. and torch._C._autograd._profiler_type() == ActiveProfilerType.LEGACY # type: ignore[attr-defined]
  781. )
  782. def _enable_rpc_profiler(
  783. should_profile, qualified_name, func, rpc_type, dst_worker_info
  784. ):
  785. ctx_manager = contextlib.nullcontext()
  786. if should_profile:
  787. # Create appropriate string representation based on type of func
  788. # (builtin, script, python)
  789. if qualified_name is None:
  790. func_name = (
  791. torch._jit_internal._qualified_name(func)
  792. if isinstance(func, torch.jit.ScriptFunction)
  793. else func.__qualname__
  794. )
  795. else:
  796. func_name = qualified_name
  797. # Build RPC profiling key.
  798. rpc_profiling_key = _build_rpc_profiling_key(
  799. rpc_type,
  800. func_name,
  801. get_worker_info().name,
  802. dst_worker_info.name,
  803. )
  804. RemoteProfilerManager.set_current_profiling_key(rpc_profiling_key)
  805. # Mypy doesn't support re-def of a variable not in the same block (#1174)
  806. ctx_manager = torch.autograd.profiler.record_function(rpc_profiling_key) # type: ignore[assignment]
  807. return ctx_manager