api.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391
  1. # mypy: allow-untyped-defs
  2. # Copyright (c) Facebook, Inc. and its affiliates.
  3. # All rights reserved.
  4. #
  5. # This source code is licensed under the BSD-style license found in the
  6. # LICENSE file in the root directory of this source tree.
  7. import socket
  8. from abc import ABC, abstractmethod
  9. from collections.abc import Callable
  10. from dataclasses import dataclass
  11. from typing import Any, ClassVar
  12. from torch.distributed import Store
  13. from torch.distributed.elastic.utils.distributed import get_free_port
  14. __all__ = [
  15. "RendezvousClosedError",
  16. "RendezvousConnectionError",
  17. "RendezvousError",
  18. "RendezvousGracefulExitError",
  19. "RendezvousHandler",
  20. "RendezvousHandlerCreator",
  21. "RendezvousHandlerRegistry",
  22. "RendezvousInfo",
  23. "RendezvousParameters",
  24. "RendezvousStateError",
  25. "RendezvousStoreInfo",
  26. "RendezvousTimeoutError",
  27. "rendezvous_handler_registry",
  28. ]
  29. class RendezvousError(Exception):
  30. """Represents the base type for rendezvous errors."""
  31. class RendezvousClosedError(RendezvousError):
  32. """Raised when a rendezvous is closed."""
  33. class RendezvousTimeoutError(RendezvousError):
  34. """Raised when a rendezvous did not complete on time."""
  35. class RendezvousConnectionError(RendezvousError):
  36. """Raised when the connection to a rendezvous backend has failed."""
  37. class RendezvousStateError(RendezvousError):
  38. """Raised when the state of a rendezvous is corrupt."""
  39. class RendezvousGracefulExitError(RendezvousError):
  40. """Raised when node wasn't not included in rendezvous and gracefully exits.
  41. Exception is a mechanism to exit the stack, however does not mean a failure.
  42. """
  43. @dataclass
  44. class RendezvousStoreInfo:
  45. """Store address and port that can be used to bootstrap trainer distributed comms"""
  46. MASTER_ADDR_KEY: ClassVar[str] = "MASTER_ADDR"
  47. MASTER_PORT_KEY: ClassVar[str] = "MASTER_PORT"
  48. master_addr: str
  49. master_port: int
  50. @staticmethod
  51. def build(
  52. rank: int,
  53. store: Store,
  54. local_addr: str | None,
  55. server_port: int | None = None,
  56. ) -> "RendezvousStoreInfo":
  57. """Factory method, finds unused new port on rank0 host and addr/port info with all ranks.
  58. If master_addr/master_port is knowns (useful when sharing existing tcp store server) use the constructor.
  59. Args:
  60. rank: rank of the current node
  61. store: store to use for rendezvous
  62. local_addr: address of the current node, if not provided will be resolved from hostname
  63. server_port: port of the TCPStore server, when the TCPStore is shared.
  64. """
  65. # TODO swap to collectives comms API
  66. if rank == 0:
  67. addr = local_addr or socket.getfqdn()
  68. # When TCPStore is not shared, we fallback to get_free_port.
  69. port = server_port or get_free_port()
  70. store.set(
  71. RendezvousStoreInfo.MASTER_ADDR_KEY,
  72. addr.encode(encoding="UTF-8"), # type: ignore[arg-type]
  73. )
  74. store.set(
  75. RendezvousStoreInfo.MASTER_PORT_KEY,
  76. str(port).encode(encoding="UTF-8"), # type: ignore[arg-type]
  77. )
  78. addr = store.get(RendezvousStoreInfo.MASTER_ADDR_KEY).decode(encoding="UTF-8")
  79. port = int(
  80. store.get(RendezvousStoreInfo.MASTER_PORT_KEY).decode(encoding="UTF-8")
  81. )
  82. return RendezvousStoreInfo(master_addr=addr, master_port=port)
  83. class RendezvousInfo:
  84. """Holds the information about the rendezvous."""
  85. def __init__(
  86. self,
  87. store: Store,
  88. rank: int,
  89. world_size: int,
  90. bootstrap_store_info: RendezvousStoreInfo,
  91. ):
  92. self._store = store
  93. self._rank = rank
  94. self._world_size = world_size
  95. self._bootstrap_store_info = bootstrap_store_info
  96. @property
  97. def store(self) -> Store:
  98. """Store used by torchelastic control plane"""
  99. return self._store
  100. @property
  101. def rank(self) -> int:
  102. """Rank within a group"""
  103. return self._rank
  104. @property
  105. def world_size(self) -> int:
  106. """Global group size"""
  107. return self._world_size
  108. @property
  109. def bootstrap_store_info(self) -> RendezvousStoreInfo | None:
  110. """Store information that can used by trainer code to bootstrap distributed comms."""
  111. return self._bootstrap_store_info
  112. class RendezvousHandler(ABC):
  113. """Main rendezvous interface.
  114. Note:
  115. Distributed Torch users normally **do not** need to implement their own
  116. ``RendezvousHandler``. An implementation based on C10d Store is already
  117. provided, and is recommended for most users.
  118. """
  119. @abstractmethod
  120. def get_backend(self) -> str:
  121. """Return the name of the rendezvous backend."""
  122. @property
  123. def use_agent_store(self) -> bool:
  124. """Indicates that store reference returned by :py:meth:`next_rendezvous` can be shared with user
  125. applications and will be available during application lifecycle.
  126. Rendezvous handler impl will share store details as instance of :py:class:`RendezvousStoreInfo`.
  127. Applications as a convention use `MASTER_ADDR`/`MASTER_PORT` env variables to lookup the store.
  128. """
  129. return False
  130. @abstractmethod
  131. def next_rendezvous(self) -> RendezvousInfo:
  132. """Main entry-point into the rendezvous barrier.
  133. Blocks until the rendezvous is complete and the current process is
  134. included in the formed worker group, or a timeout occurs, or the
  135. rendezvous was marked closed.
  136. Returns:
  137. Instance of :py:class:`RendezvousInfo`.
  138. Raises:
  139. RendezvousClosedError:
  140. The rendezvous is closed.
  141. RendezvousConnectionError:
  142. The connection to the rendezvous backend has failed.
  143. RendezvousStateError:
  144. The rendezvous state is corrupt.
  145. RendezvousTimeoutError:
  146. The rendezvous did not complete on time.
  147. """
  148. @abstractmethod
  149. def is_closed(self) -> bool:
  150. """Check whether the rendezvous has been closed.
  151. A closed rendezvous means all future attempts to re-rendezvous within
  152. same job will fail.
  153. ``is_closed()`` and :py:meth:`set_closed` have semantics of eventual
  154. propagation and should not be used for synchronization. The intention is
  155. that if at least one node decides the job is finished, it will close the
  156. rendezvous, and other nodes will soon observe this and stop running as
  157. well.
  158. """
  159. @abstractmethod
  160. def set_closed(self):
  161. """Mark the rendezvous as closed."""
  162. @abstractmethod
  163. def num_nodes_waiting(self) -> int:
  164. """Return the number of nodes who arrived late at the rendezvous
  165. barrier, hence were not included in the current worker group.
  166. Callers should periodically call this method to check whether new
  167. nodes are waiting to join the job and if so admit them by calling
  168. :py:meth:`next_rendezvous()` (re-rendezvous).
  169. """
  170. @abstractmethod
  171. def get_run_id(self) -> str:
  172. """Return the run id of the rendezvous.
  173. The run id is a user-defined id that uniquely identifies an instance of
  174. a distributed application. It typically maps to a job id and is used to
  175. allow nodes to join the correct distributed application.
  176. """
  177. @abstractmethod
  178. def shutdown(self) -> bool:
  179. """Close all resources that were open for the rendezvous.
  180. Example::
  181. rdzv_handler = ...
  182. try:
  183. store, rank, world_size = rdzv_handler.next_rendezvous()
  184. finally:
  185. rdzv_handler.shutdown()
  186. """
  187. class RendezvousParameters:
  188. """Hold the parameters to construct a :py:class:`RendezvousHandler`.
  189. Args:
  190. backend:
  191. The name of the backend to use to handle the rendezvous.
  192. endpoint:
  193. The endpoint of the rendezvous, usually in form <hostname>[:<port>].
  194. run_id:
  195. The id of the rendezvous.
  196. min_nodes:
  197. The minimum number of nodes to admit to the rendezvous.
  198. max_nodes:
  199. The maximum number of nodes to admit to the rendezvous.
  200. local_addr:
  201. The address of the local node.
  202. **kwargs:
  203. Additional parameters for the specified backend.
  204. """
  205. def __init__(
  206. self,
  207. backend: str,
  208. endpoint: str,
  209. run_id: str,
  210. min_nodes: int,
  211. max_nodes: int,
  212. local_addr: str | None = None,
  213. **kwargs,
  214. ):
  215. if not backend:
  216. raise ValueError("The rendezvous backend name must be a non-empty string.")
  217. if min_nodes < 1:
  218. raise ValueError(
  219. f"The minimum number of rendezvous nodes ({min_nodes}) must be greater than zero."
  220. )
  221. if max_nodes < min_nodes:
  222. raise ValueError(
  223. f"The maximum number of rendezvous nodes ({max_nodes}) must be greater than or "
  224. f"equal to the minimum number of rendezvous nodes ({min_nodes})."
  225. )
  226. self.backend = backend
  227. self.endpoint = endpoint
  228. self.run_id = run_id
  229. self.min_nodes = min_nodes
  230. self.max_nodes = max_nodes
  231. self.config = kwargs
  232. self.local_addr = local_addr
  233. def get(self, key: str, default: Any = None) -> Any:
  234. """Return the value for ``key`` if ``key`` exists, else ``default``."""
  235. return self.config.get(key, default)
  236. def get_as_bool(self, key: str, default: bool | None = None) -> bool | None:
  237. """Return the value for ``key`` as a ``bool``."""
  238. value = self.get(key, default)
  239. if value is None or isinstance(value, bool):
  240. return value
  241. if isinstance(value, int):
  242. if value == 1:
  243. return True
  244. if value == 0:
  245. return False
  246. elif isinstance(value, str):
  247. if value.lower() in ["1", "true", "t", "yes", "y"]:
  248. return True
  249. if value.lower() in ["0", "false", "f", "no", "n"]:
  250. return False
  251. raise ValueError(
  252. f"The rendezvous configuration option '{key}' does not represent a valid boolean value."
  253. )
  254. def get_as_int(self, key: str, default: int | None = None) -> int | None:
  255. """Return the value for ``key`` as an ``int``."""
  256. value = self.get(key, default)
  257. if value is None:
  258. return value
  259. try:
  260. return int(value)
  261. except ValueError as e:
  262. raise ValueError(
  263. f"The rendezvous configuration option '{key}' does not represent a valid integer "
  264. "value."
  265. ) from e
  266. RendezvousHandlerCreator = Callable[[RendezvousParameters], RendezvousHandler]
  267. class RendezvousHandlerRegistry:
  268. """Represent a registry of :py:class:`RendezvousHandler` backends."""
  269. _registry: dict[str, RendezvousHandlerCreator]
  270. def __init__(self) -> None:
  271. self._registry = {}
  272. def register(self, backend: str, creator: RendezvousHandlerCreator) -> None:
  273. """Register a new rendezvous backend.
  274. Args:
  275. backend:
  276. The name of the backend.
  277. creator:
  278. The callback to invoke to construct the
  279. :py:class:`RendezvousHandler`.
  280. """
  281. if not backend:
  282. raise ValueError("The rendezvous backend name must be a non-empty string.")
  283. current_creator: RendezvousHandlerCreator | None
  284. try:
  285. current_creator = self._registry[backend]
  286. except KeyError:
  287. current_creator = None
  288. if current_creator is not None and current_creator != creator:
  289. raise ValueError(
  290. f"The rendezvous backend '{backend}' cannot be registered with '{creator}' as it "
  291. f"is already registered with '{current_creator}'."
  292. )
  293. self._registry[backend] = creator
  294. def create_handler(self, params: RendezvousParameters) -> RendezvousHandler:
  295. """Create a new :py:class:`RendezvousHandler`."""
  296. try:
  297. creator = self._registry[params.backend]
  298. except KeyError as e:
  299. raise ValueError(
  300. f"The rendezvous backend '{params.backend}' is not registered. Did you forget "
  301. f"to call `{self.register.__name__}`?"
  302. ) from e
  303. handler = creator(params)
  304. # Do some sanity check.
  305. if handler.get_backend() != params.backend:
  306. raise RuntimeError(
  307. f"The rendezvous backend '{handler.get_backend()}' does not match the requested "
  308. f"backend '{params.backend}'."
  309. )
  310. return handler
  311. # The default global registry instance used by launcher scripts to instantiate
  312. # rendezvous handlers.
  313. rendezvous_handler_registry = RendezvousHandlerRegistry()