options.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. # mypy: allow-untyped-defs
  2. from typing import Union
  3. import torch
  4. from . import _is_tensorpipe_available, constants as rpc_contants
# Device specifications accepted by the helpers below; each one is
# normalized through ``torch.device(...)`` before use.
DeviceType = Union[int, str, torch.device]
# Only the options class is exported; the ``_to_device*`` helpers are private.
__all__ = ["TensorPipeRpcBackendOptions"]
  7. def _to_device(device: DeviceType) -> torch.device:
  8. device = torch.device(device)
  9. if device.type != "cuda":
  10. raise ValueError(
  11. "`set_devices` expect a list of CUDA devices, but got "
  12. f"device type {device.type}."
  13. )
  14. return device
  15. def _to_device_map(
  16. device_map: dict[DeviceType, DeviceType],
  17. ) -> dict[torch.device, torch.device]:
  18. full_device_map: dict[torch.device, torch.device] = {}
  19. reverse_map: dict[torch.device, torch.device] = {}
  20. for k, v in device_map.items():
  21. k, v = torch.device(k), torch.device(v)
  22. if v in reverse_map:
  23. raise ValueError(
  24. "`device_map` only supports 1-to-1 mapping, "
  25. f"trying to map {k} and {reverse_map[v]} to {v}"
  26. )
  27. full_device_map[k] = v
  28. reverse_map[v] = k
  29. return full_device_map
  30. def _to_device_list(devices: list[DeviceType]) -> list[torch.device]:
  31. return list(map(_to_device, devices))
# Inherit from the native options base only when ``_is_tensorpipe_available``
# is truthy; otherwise fall back to ``object`` so the class definition below
# can still be evaluated.
if _is_tensorpipe_available:  # type: ignore[has-type]
    from torch._C._distributed_rpc import _TensorPipeRpcBackendOptionsBase
else:
    _TensorPipeRpcBackendOptionsBase = object  # type: ignore[assignment, misc]
  36. # pyrefly: ignore [invalid-inheritance]
  37. class TensorPipeRpcBackendOptions(_TensorPipeRpcBackendOptionsBase):
  38. r"""
  39. The backend options for
  40. :class:`~torch.distributed.rpc.TensorPipeAgent`, derived from
  41. :class:`~torch.distributed.rpc.RpcBackendOptions`.
  42. Args:
  43. num_worker_threads (int, optional): The number of threads in the
  44. thread-pool used by
  45. :class:`~torch.distributed.rpc.TensorPipeAgent` to execute
  46. requests (default: 16).
  47. rpc_timeout (float, optional): The default timeout, in seconds,
  48. for RPC requests (default: 60 seconds). If the RPC has not
  49. completed in this timeframe, an exception indicating so will
  50. be raised. Callers can override this timeout for individual
  51. RPCs in :meth:`~torch.distributed.rpc.rpc_sync` and
  52. :meth:`~torch.distributed.rpc.rpc_async` if necessary.
  53. init_method (str, optional): The URL to initialize the distributed
  54. store used for rendezvous. It takes any value accepted for the
  55. same argument of :meth:`~torch.distributed.init_process_group`
  56. (default: ``env://``).
  57. device_maps (Dict[str, Dict], optional): Device placement mappings from
  58. this worker to the callee. Key is the callee worker name and value
  59. the dictionary (``Dict`` of ``int``, ``str``, or ``torch.device``)
  60. that maps this worker's devices to the callee worker's devices.
  61. (default: ``None``)
  62. devices (List[int, str, or ``torch.device``], optional): all local
  63. CUDA devices used by RPC agent. By Default, it will be initialized
  64. to all local devices from its own ``device_maps`` and corresponding
  65. devices from its peers' ``device_maps``. When processing CUDA RPC
  66. requests, the agent will properly synchronize CUDA streams for
  67. all devices in this ``List``.
  68. """
  69. def __init__(
  70. self,
  71. *,
  72. num_worker_threads: int = rpc_contants.DEFAULT_NUM_WORKER_THREADS,
  73. rpc_timeout: float = rpc_contants.DEFAULT_RPC_TIMEOUT_SEC,
  74. init_method: str = rpc_contants.DEFAULT_INIT_METHOD,
  75. device_maps: dict[str, dict[DeviceType, DeviceType]] | None = None,
  76. devices: list[DeviceType] | None = None,
  77. _transports: list | None = None,
  78. _channels: list | None = None,
  79. ):
  80. full_device_maps = (
  81. {}
  82. if device_maps is None
  83. else {k: _to_device_map(v) for k, v in device_maps.items()}
  84. )
  85. full_device_list = [] if devices is None else _to_device_list(devices)
  86. super().__init__(
  87. num_worker_threads,
  88. _transports,
  89. _channels,
  90. rpc_timeout,
  91. init_method,
  92. full_device_maps,
  93. full_device_list,
  94. )
  95. def set_device_map(self, to: str, device_map: dict[DeviceType, DeviceType]):
  96. r"""
  97. Set device mapping between each RPC caller and callee pair. This
  98. function can be called multiple times to incrementally add
  99. device placement configurations.
  100. Args:
  101. to (str): Callee name.
  102. device_map (Dict of int, str, or torch.device): Device placement
  103. mappings from this worker to the callee. This map must be
  104. invertible.
  105. Example:
  106. >>> # xdoctest: +SKIP("distributed")
  107. >>> # both workers
  108. >>> def add(x, y):
  109. >>> print(x) # tensor([1., 1.], device='cuda:1')
  110. >>> return x + y, (x + y).to(2)
  111. >>>
  112. >>> # on worker 0
  113. >>> options = TensorPipeRpcBackendOptions(
  114. >>> num_worker_threads=8,
  115. >>> device_maps={"worker1": {0: 1}}
  116. >>> # maps worker0's cuda:0 to worker1's cuda:1
  117. >>> )
  118. >>> options.set_device_map("worker1", {1: 2})
  119. >>> # maps worker0's cuda:1 to worker1's cuda:2
  120. >>>
  121. >>> rpc.init_rpc(
  122. >>> "worker0",
  123. >>> rank=0,
  124. >>> world_size=2,
  125. >>> backend=rpc.BackendType.TENSORPIPE,
  126. >>> rpc_backend_options=options
  127. >>> )
  128. >>>
  129. >>> x = torch.ones(2)
  130. >>> rets = rpc.rpc_sync("worker1", add, args=(x.to(0), 1))
  131. >>> # The first argument will be moved to cuda:1 on worker1. When
  132. >>> # sending the return value back, it will follow the invert of
  133. >>> # the device map, and hence will be moved back to cuda:0 and
  134. >>> # cuda:1 on worker0
  135. >>> print(rets[0]) # tensor([2., 2.], device='cuda:0')
  136. >>> print(rets[1]) # tensor([2., 2.], device='cuda:1')
  137. """
  138. full_device_map = _to_device_map(device_map)
  139. curr_device_maps = super().device_maps
  140. if to in curr_device_maps:
  141. for k, v in full_device_map.items():
  142. if k in curr_device_maps[to] and v != curr_device_maps[to][k]:
  143. raise ValueError(
  144. "`set_device_map` only supports 1-to-1 mapping, trying"
  145. f" to map {k} to {v} and {curr_device_maps[to][k]}"
  146. )
  147. super()._set_device_map(to, full_device_map)
  148. def set_devices(self, devices: list[DeviceType]):
  149. r"""
  150. Set local devices used by the TensorPipe RPC agent. When processing
  151. CUDA RPC requests, the TensorPipe RPC agent will properly synchronize
  152. CUDA streams for all devices in this ``List``.
  153. Args:
  154. devices (List of int, str, or torch.device): local devices used by
  155. the TensorPipe RPC agent.
  156. """
  157. self.devices = _to_device_list(devices)