backend_registry.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431
  1. # mypy: allow-untyped-defs
  2. import collections
  3. import enum
  4. from typing import cast
  5. import torch
  6. import torch.distributed as dist
  7. from . import api, constants as rpc_constants
  8. from ._utils import _group_membership_management, _update_group_membership
# Names exported by ``from torch.distributed.rpc.backend_registry import *``.
# Kept as a literal list so static analysis and doc tooling can read it.
__all__ = [
    "backend_registered",
    "register_backend",
    "construct_rpc_backend_options",
    "init_backend",
    "BackendValue",
    "BackendType",
]
  17. BackendValue = collections.namedtuple(
  18. "BackendValue", ["construct_rpc_backend_options_handler", "init_backend_handler"]
  19. )
  20. def _backend_type_repr(self):
  21. return "BackendType." + self.name
# Docstring assigned to ``BackendType`` (only when the enum class already has a
# truthy ``__doc__``), both when the module is imported and each time
# ``register_backend`` rebuilds the enum class.
_backend_type_doc = """
An enum class of available backends.
PyTorch ships with a builtin ``BackendType.TENSORPIPE`` backend.
Additional ones can be registered using the
:func:`~torch.distributed.rpc.backend_registry.register_backend` function.
"""
  28. # Create an enum type, `BackendType`, with empty members.
  29. # Can't handle Function Enum API (mypy bug #9079)
  30. BackendType = enum.Enum(value="BackendType", names={}) # type: ignore[misc]
  31. # Unable to assign a function a method (mypy bug #2427)
  32. BackendType.__repr__ = _backend_type_repr # type: ignore[assignment]
  33. if BackendType.__doc__:
  34. BackendType.__doc__ = _backend_type_doc
  35. def backend_registered(backend_name):
  36. """
  37. Checks if backend_name is registered as an RPC backend.
  38. Args:
  39. backend_name (str): string to identify the RPC backend.
  40. Returns:
  41. True if the backend has been registered with ``register_backend``, else
  42. False.
  43. """
  44. return backend_name in BackendType.__members__
  45. def register_backend(
  46. backend_name, construct_rpc_backend_options_handler, init_backend_handler
  47. ):
  48. """Registers a new RPC backend.
  49. Args:
  50. backend_name (str): backend string to identify the handler.
  51. construct_rpc_backend_options_handler (function):
  52. Handler that is invoked when
  53. rpc_backend.construct_rpc_backend_options(**dict) is called.
  54. init_backend_handler (function): Handler that is invoked when the
  55. `_init_rpc_backend()` function is called with a backend.
  56. This returns the agent.
  57. """
  58. global BackendType
  59. if backend_registered(backend_name):
  60. raise RuntimeError(f"RPC backend {backend_name}: already registered")
  61. # Create a new enum type, `BackendType`, with extended members.
  62. existing_enum_dict = {member.name: member.value for member in BackendType}
  63. extended_enum_dict = dict(
  64. {
  65. backend_name: BackendValue(
  66. construct_rpc_backend_options_handler=construct_rpc_backend_options_handler,
  67. init_backend_handler=init_backend_handler,
  68. )
  69. },
  70. **existing_enum_dict,
  71. )
  72. # Can't handle Function Enum API (mypy bug #9079)
  73. BackendType = enum.Enum(value="BackendType", names=extended_enum_dict) # type: ignore[misc]
  74. # Unable to assign a function a method (mypy bug #2427)
  75. BackendType.__repr__ = _backend_type_repr # type: ignore[assignment]
  76. if BackendType.__doc__:
  77. BackendType.__doc__ = _backend_type_doc
  78. return BackendType[backend_name]
  79. def construct_rpc_backend_options(
  80. backend,
  81. rpc_timeout=rpc_constants.DEFAULT_RPC_TIMEOUT_SEC,
  82. init_method=rpc_constants.DEFAULT_INIT_METHOD,
  83. **kwargs,
  84. ):
  85. return backend.value.construct_rpc_backend_options_handler(
  86. rpc_timeout, init_method, **kwargs
  87. )
  88. def init_backend(backend, *args, **kwargs):
  89. return backend.value.init_backend_handler(*args, **kwargs)
  90. def _init_process_group(store, rank, world_size):
  91. # Initialize ProcessGroup.
  92. process_group_timeout = rpc_constants.DEFAULT_PROCESS_GROUP_TIMEOUT
  93. # We're using a bunch of private APIs here since `new_group` requires the
  94. # default group to be initialized.
  95. group = dist.ProcessGroupGloo(store, rank, world_size, process_group_timeout)
  96. assert group is not None, "Failed to initialize default ProcessGroup."
  97. if (rank != -1) and (rank != group.rank()):
  98. raise RuntimeError(f"rank argument {rank} doesn't match pg rank {group.rank()}")
  99. if (world_size != -1) and (world_size != group.size()):
  100. raise RuntimeError(
  101. f"world_size argument {world_size} doesn't match pg size {group.size()}"
  102. )
  103. return group
  104. def _tensorpipe_construct_rpc_backend_options_handler(
  105. rpc_timeout,
  106. init_method,
  107. num_worker_threads=rpc_constants.DEFAULT_NUM_WORKER_THREADS,
  108. _transports=None,
  109. _channels=None,
  110. **kwargs,
  111. ):
  112. from . import TensorPipeRpcBackendOptions
  113. return TensorPipeRpcBackendOptions(
  114. rpc_timeout=rpc_timeout,
  115. init_method=init_method,
  116. num_worker_threads=num_worker_threads,
  117. _transports=_transports,
  118. _channels=_channels,
  119. )
  120. def _tensorpipe_validate_devices(devices, device_count):
  121. return all(
  122. d.type == "cpu" or (d.type == "cuda" and 0 <= d.index < device_count)
  123. for d in devices
  124. )
  125. # detect if any worker has invalid device_map configurations, and return
  126. # reverse device maps
  127. def _tensorpipe_exchange_and_check_all_device_maps(
  128. my_name, my_device_count, my_device_maps, my_devices, group
  129. ):
  130. gathered: list[
  131. tuple[str, int, dict[str, dict[torch.device, torch.device]], list[torch.device]]
  132. ] = [("", 0, {}, []) for _ in range(group.size())]
  133. dist.all_gather_object(
  134. gathered, (my_name, my_device_count, my_device_maps, my_devices), group
  135. )
  136. all_names = [name for name, _, _, _ in gathered]
  137. all_device_counts = {name: count for name, count, _, _ in gathered}
  138. all_device_maps = {name: map_ for name, _, map_, _ in gathered}
  139. all_devices = {name: devices for name, _, _, devices in gathered}
  140. _validate_device_maps(all_names, all_device_counts, all_device_maps, all_devices)
  141. # passed all checked, construct reverse mapping and get list of devices handled by this agent
  142. reverse_device_maps = _create_reverse_mapping(my_name, all_names, all_device_maps)
  143. my_devices = _create_device_list(my_devices, my_device_maps, reverse_device_maps)
  144. return reverse_device_maps, my_devices
  145. def _validate_device_maps(
  146. all_names, all_device_counts, all_device_maps, all_devices, is_static_group=True
  147. ):
  148. for node in all_names:
  149. devices = all_devices[node]
  150. if len(set(devices)) != len(devices):
  151. raise ValueError(f"Node {node} has duplicated devices\ndevices = {devices}")
  152. if not _tensorpipe_validate_devices(devices, all_device_counts[node]):
  153. raise ValueError(
  154. f"Node {node} has devices with invalid indices\n"
  155. f"devices = {devices}\n"
  156. f"device count = {all_device_counts[node]}"
  157. )
  158. for source_node in all_names:
  159. # For dynamic group (non-static) do not check the target node name since it may not have joined yet
  160. if is_static_group and not set(all_device_maps[source_node].keys()).issubset(
  161. all_names
  162. ):
  163. raise ValueError(
  164. f"Node {source_node} has invalid target node names in its device maps\n"
  165. f"device maps = {all_device_maps[source_node].keys()}\n"
  166. f"node names = {all_names}"
  167. )
  168. for target_node, map_ in all_device_maps[source_node].items():
  169. if len(set(map_.values())) != len(map_):
  170. raise ValueError(
  171. f"Node {source_node} has duplicated target devices "
  172. f"in its device map for {target_node}\n"
  173. f"device map = {map_}"
  174. )
  175. if all_devices[source_node]:
  176. if not set(map_.keys()).issubset(all_devices[source_node]):
  177. raise ValueError(
  178. f"Node {source_node} has unexpected source devices "
  179. f"in its device map for {target_node}\n"
  180. f"device map = {map_}\n"
  181. f"devices = {all_devices[source_node]}"
  182. )
  183. elif not _tensorpipe_validate_devices(
  184. map_.keys(), all_device_counts[source_node]
  185. ):
  186. raise ValueError(
  187. f"Node {source_node} has source devices with invalid indices "
  188. f"in its device map for {target_node}\n"
  189. f"device map = {map_}\n"
  190. f"device count = {all_device_counts[source_node]}"
  191. )
  192. if all_devices.get(target_node, []):
  193. if not set(map_.values()).issubset(all_devices[target_node]):
  194. raise ValueError(
  195. f"Node {source_node} has unexpected target devices "
  196. f"in its device map for {target_node}\n"
  197. f"device map = {map_}\n"
  198. f"devices = {all_devices[target_node]}"
  199. )
  200. elif target_node in all_device_counts and not _tensorpipe_validate_devices(
  201. map_.values(), all_device_counts[target_node]
  202. ):
  203. raise ValueError(
  204. f"Node {source_node} has target devices with invalid indices "
  205. f"in its device map for {target_node}\n"
  206. f"device map = {map_}\n"
  207. f"device count = {all_device_counts[target_node]}"
  208. )
  209. def _create_device_list(my_devices, my_device_maps, reverse_device_maps):
  210. if not my_devices:
  211. devices_set: set[torch.device] = set()
  212. for map_ in my_device_maps.values():
  213. devices_set.update(map_.keys())
  214. for map_ in reverse_device_maps.values():
  215. devices_set.update(map_.keys())
  216. devices_set.discard(torch.device("cpu"))
  217. my_devices = list(devices_set)
  218. my_devices = sorted(my_devices, key=lambda d: d.index)
  219. return my_devices
  220. def _create_reverse_mapping(my_name, all_names, all_device_maps):
  221. reverse_device_maps: dict[str, dict[torch.device, torch.device]] = {}
  222. for node in all_names:
  223. if my_name in all_device_maps[node]:
  224. reverse_device_maps[node] = {
  225. v: k for k, v in all_device_maps[node][my_name].items()
  226. }
  227. return reverse_device_maps
  228. def _get_device_infos():
  229. from . import TensorPipeAgent
  230. agent = cast(TensorPipeAgent, api._get_current_rpc_agent())
  231. opts = agent._get_backend_options()
  232. device_count = torch.cuda.device_count()
  233. if torch.cuda.is_available() and opts.devices:
  234. torch.cuda.init()
  235. return device_count, opts.device_maps, opts.devices
  236. def _set_devices_and_reverse_device_map(agent):
  237. from . import TensorPipeAgent
  238. agent = cast(TensorPipeAgent, agent)
  239. # Group state is retrieved from local agent
  240. # On initialization, tensorpipe agent retrieves information from all existing workers, so group state is valid
  241. my_worker_info = agent.get_worker_info()
  242. my_name = my_worker_info.name
  243. all_worker_infos = agent.get_worker_infos()
  244. # One round to get device_maps of all workers and construct reverse device maps
  245. all_device_counts, all_device_maps, all_devices, all_names = {}, {}, {}, []
  246. for worker_info in all_worker_infos:
  247. worker_name = worker_info.name
  248. if worker_name != my_name:
  249. # TODO: make async?
  250. device_count, device_map, devices = api.rpc_sync(
  251. worker_name, _get_device_infos
  252. )
  253. else:
  254. opts = agent._get_backend_options()
  255. device_count, device_map, devices = (
  256. torch.cuda.device_count(),
  257. opts.device_maps,
  258. opts.devices,
  259. )
  260. all_device_counts[worker_name] = device_count
  261. all_device_maps[worker_name] = device_map
  262. all_devices[worker_name] = devices
  263. all_names.append(worker_name)
  264. _validate_device_maps(
  265. all_names,
  266. all_device_counts,
  267. all_device_maps,
  268. all_devices,
  269. is_static_group=False,
  270. )
  271. reverse_device_maps = _create_reverse_mapping(my_name, all_names, all_device_maps)
  272. # Perform RPC call to all workers, including itself, to include newly joined worker information and device maps
  273. for worker_name in all_names:
  274. # Set device list for each worker
  275. all_devices[worker_name] = _create_device_list(
  276. all_devices[worker_name], all_device_maps[worker_name], reverse_device_maps
  277. )
  278. api.rpc_sync(
  279. worker_name,
  280. _update_group_membership,
  281. args=(my_worker_info, all_devices[worker_name], reverse_device_maps, True),
  282. )
def _tensorpipe_init_backend_handler(
    store, name, rank, world_size, rpc_backend_options
):
    """Create and start the ``TensorPipeAgent`` for this worker.

    The mode is selected by the truthiness of ``world_size``:

    * Static group (truthy ``world_size``). A Gloo process group is built over
      ``store``. Every rank's device maps and devices are all-gathered and
      validated. CUDA is initialized when devices are in use, and the agent is
      constructed with the resulting reverse device maps and devices. A dummy
      ``_all_gather`` and a process-group barrier then run before the agent is
      returned.
    * Dynamic group (falsy ``world_size``, e.g. ``None`` or ``0``). Inside
      ``_group_membership_management``, the agent is constructed with empty
      reverse device maps and devices, and
      ``_set_devices_and_reverse_device_map`` then exchanges device
      information with the workers already in the group. If that fails,
      ``api.shutdown()`` is called and the exception is re-raised.

    Args:
        store (dist.Store): c10d store shared by the workers.
        name (str): this worker's name.
        rank (int): this worker's rank.
        world_size: number of workers; falsy selects a dynamic group.
        rpc_backend_options (TensorPipeRpcBackendOptions): backend options.

    Returns:
        The constructed ``TensorPipeAgent``.

    Raises:
        TypeError: if ``store`` is not a ``dist.Store`` or
            ``rpc_backend_options`` is not a ``TensorPipeRpcBackendOptions``.
    """
    from . import TensorPipeAgent, TensorPipeRpcBackendOptions

    if not isinstance(store, dist.Store):
        raise TypeError(f"`store` must be a c10d::Store. {store}")

    if not isinstance(rpc_backend_options, TensorPipeRpcBackendOptions):
        raise TypeError(
            f"`rpc_backend_options` must be a `TensorPipeRpcBackendOptions`. {rpc_backend_options}"
        )

    device_count = torch.cuda.device_count()

    # Any truthy world_size selects a static group; a falsy one (e.g. None or
    # 0) selects the dynamic path below.
    is_static_group = bool(world_size)
    # world_size is specified so this is a static group (ranks cannot join and leave)
    if is_static_group:
        # The agent's join method is required to behave like a barrier and perform
        # collective operations, for which it relies on a process group, instead of
        # re-implementing this on top of RPCs.
        group = _init_process_group(store, rank, world_size)

        # Collective step: every rank shares its device configuration, which is
        # validated and turned into this worker's reverse maps and device list.
        reverse_device_maps, devices = _tensorpipe_exchange_and_check_all_device_maps(
            name,
            device_count,
            rpc_backend_options.device_maps,
            rpc_backend_options.devices,
            group,
        )

        if torch.cuda.is_available() and devices:
            # It's necessary to initialize PyTorch CUDA states here (e.g.,
            # CUDACachingAllocator). If this is missing, we could hit errors like
            # "allocator not initialized", because other processes might send
            # CUDA-related RPC request to this process before user code in this
            # process initializes its PyTorch CUDA states.
            torch.cuda.init()

        # TODO: add try-except and destroy _agent in all processes if any fails.
        agent = TensorPipeAgent(
            store,
            name,
            rank,
            world_size,
            rpc_backend_options,
            reverse_device_maps,
            devices,
        )

        api._init_rpc_states(agent)

        # Run one dummy round of RPC to initialize channels/transports. Without
        # this, it's easy to hit timeout in rpc.shutdown() if there is no other RPC
        # on that process before rpc.shutdown(), as the agent initialization can
        # take longer than 5s.
        api._all_gather(None, timeout=rpc_backend_options.rpc_timeout)
        # Need a barrier here to make sure no peers leave before the rank0 finishes
        # _all_gather
        group.barrier().wait()

        return agent
    # initialization for dynamic rpc (ranks can join and leave)
    else:
        with _group_membership_management(store, name, True):
            # Construct TPAgent with empty reverse_device_map and devices
            # these properties will be updated after initialization
            agent = TensorPipeAgent(
                store,
                name,
                rank,
                world_size,
                rpc_backend_options,
                {},
                [],
            )
            api._init_rpc_states(agent)

            try:
                # Notify all workers in group this rank has joined and set devices and reverse_device_map
                # This is a synchronous operation that completes once all existing ranks are updated
                _set_devices_and_reverse_device_map(agent)
            except Exception:
                # Shut RPC down before propagating the failure to the caller.
                api.shutdown()
                raise
            return agent
  358. register_backend(
  359. "TENSORPIPE",
  360. _tensorpipe_construct_rpc_backend_options_handler,
  361. _tensorpipe_init_backend_handler,
  362. )