api.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337
  1. #!/usr/bin/env python3
  2. # mypy: allow-untyped-defs
  3. # Copyright (c) Facebook, Inc. and its affiliates.
  4. # All rights reserved.
  5. #
  6. # This source code is licensed under the BSD-style license found in the
  7. # LICENSE file in the root directory of this source tree.
  8. import os
  9. import sys
  10. import uuid
  11. from collections.abc import Callable
  12. from dataclasses import dataclass, field
  13. from typing import Any
  14. import torch
  15. import torch.distributed.elastic.rendezvous.registry as rdzv_registry
  16. from torch._utils_internal import get_default_numa_options
  17. from torch.distributed.elastic import events, metrics
  18. from torch.distributed.elastic.agent.server.api import WorkerSpec
  19. from torch.distributed.elastic.agent.server.local_elastic_agent import LocalElasticAgent
  20. from torch.distributed.elastic.multiprocessing import (
  21. DefaultLogsSpecs,
  22. LogsSpecs,
  23. SignalException,
  24. )
  25. from torch.distributed.elastic.multiprocessing.errors import ChildFailedError
  26. from torch.distributed.elastic.rendezvous import RendezvousParameters
  27. from torch.distributed.elastic.rendezvous.utils import parse_rendezvous_endpoint
  28. from torch.distributed.elastic.utils.logging import get_logger
  29. from torch.numa.binding import NumaOptions
  30. __all__ = ["LaunchConfig", "elastic_launch", "launch_agent"]
  31. logger = get_logger(__name__)
  32. @dataclass
  33. class LaunchConfig:
  34. """
  35. Creates a rendezvous config.
  36. Args:
  37. min_nodes: Minimum amount of nodes that the user function will
  38. be launched on. Elastic agent ensures that the user
  39. function start only when the min_nodes amount enters
  40. the rendezvous.
  41. max_nodes: Maximum amount of nodes that the user function
  42. will be launched on.
  43. nproc_per_node: On each node the elastic agent will launch
  44. this amount of workers that will execute user
  45. defined function.
  46. rdzv_backend: rdzv_backend to use in the rendezvous (zeus-adapter, etcd).
  47. rdzv_endpoint: The endpoint of the rdzv sync. storage.
  48. rdzv_configs: Key, value pair that specifies rendezvous specific configuration.
  49. rdzv_timeout: Legacy argument that specifies timeout for the rendezvous. It is going
  50. to be removed in future versions, see the note below. The default timeout is 900 seconds.
  51. run_id: The unique run id of the job (if not passed a unique one will be
  52. deduced from run environment - flow workflow id in flow - or auto generated).
  53. role: User defined role of the worker (defaults to "trainer").
  54. max_restarts: The maximum amount of restarts that elastic agent will conduct
  55. on workers before failure.
  56. monitor_interval: The interval in seconds that is used by the elastic_agent
  57. as a period of monitoring workers.
  58. start_method: The method is used by the elastic agent to start the
  59. workers (spawn, fork, forkserver).
  60. metrics_cfg: configuration to initialize metrics.
  61. local_addr: address of the local node if any. If not set, a lookup on the local
  62. machine's FQDN will be performed.
  63. local_ranks_filter: ranks for which to show logs in console. If not set, show from all.
  64. event_log_handler: name of the event logging handler as registered in
  65. `elastic/events/handlers.py <https://docs.pytorch.org/docs/stable/elastic/events.html>`_.
  66. duplicate_stdout_filters: If non-empty, duplicates stdout to a file containing only lines
  67. that match _any_ of the filter strings.
  68. duplicate_stderr_filters: If non-empty, duplicates stderr to a file containing only lines
  69. that match _any_ of the filter strings.
  70. virtual_local_rank: Enable virtual local rank mode for workers (defaults to False).
  71. When enabled, LOCAL_RANK is set to 0 for all workers and
  72. CUDA_VISIBLE_DEVICES is adjusted so each worker accesses its
  73. assigned GPU at device index 0.
  74. .. note::
  75. `rdzv_timeout` is a legacy argument that will be removed in future.
  76. Set the timeout via `rdzv_configs['timeout']`
  77. """
  78. min_nodes: int
  79. max_nodes: int
  80. nproc_per_node: int
  81. logs_specs: LogsSpecs | None = None
  82. run_id: str = ""
  83. role: str = "default_role"
  84. rdzv_endpoint: str = ""
  85. rdzv_backend: str = "etcd"
  86. rdzv_configs: dict[str, Any] = field(default_factory=dict)
  87. rdzv_timeout: int = -1
  88. max_restarts: int = 3
  89. monitor_interval: float = 0.1
  90. start_method: str = "spawn"
  91. log_line_prefix_template: str | None = None
  92. metrics_cfg: dict[str, str] = field(default_factory=dict)
  93. local_addr: str | None = None
  94. event_log_handler: str = "null"
  95. numa_options: NumaOptions | None = None
  96. signals_to_handle: str = "SIGTERM,SIGINT,SIGHUP,SIGQUIT"
  97. duplicate_stdout_filters: list[str] | None = None
  98. duplicate_stderr_filters: list[str] | None = None
  99. virtual_local_rank: bool = False
  100. def __post_init__(self):
  101. default_timeout = 900
  102. if self.rdzv_timeout != -1:
  103. self.rdzv_configs["timeout"] = self.rdzv_timeout
  104. elif "timeout" not in self.rdzv_configs:
  105. self.rdzv_configs["timeout"] = default_timeout
  106. # Post-processing to enable refactoring to introduce logs_specs due to non-torchrun API usage
  107. if self.logs_specs is None:
  108. self.logs_specs = DefaultLogsSpecs()
  109. if (
  110. self.numa_options is None
  111. and torch.cuda.is_available()
  112. # We assume local_rank n uses cuda device n.
  113. and torch.cuda.device_count() == self.nproc_per_node
  114. ):
  115. self.numa_options = get_default_numa_options()
  116. logger.info("Using default numa options = %r", self.numa_options)
  117. class elastic_launch:
  118. """
  119. Launches an torchelastic agent on the container that invoked the entrypoint.
  120. 1. Pass the ``entrypoint`` arguments as non ``kwargs`` (e.g. no named parameters)/
  121. ``entrypoint`` can be a function or a command.
  122. 2. The return value is a map of each worker's output mapped
  123. by their respective global rank.
  124. Usage
  125. ::
  126. def worker_fn(foo):
  127. # ...
  128. def main():
  129. # entrypoint is a function.
  130. outputs = elastic_launch(LaunchConfig, worker_fn)(foo)
  131. # return rank 0's output
  132. return outputs[0]
  133. # entrypoint is a command and ``script.py`` is the python module.
  134. outputs = elastic_launch(LaunchConfig, "script.py")(args)
  135. outputs = elastic_launch(LaunchConfig, "python")("script.py")
  136. """
  137. def __init__(
  138. self,
  139. config: LaunchConfig,
  140. entrypoint: Callable | str | None,
  141. ):
  142. self._config = config
  143. self._entrypoint = entrypoint
  144. def __call__(self, *args):
  145. return launch_agent(self._config, self._entrypoint, list(args))
  146. def _get_entrypoint_name(entrypoint: Callable | str | None, args: list[Any]) -> str:
  147. """Retrieve entrypoint name with the rule:
  148. 1. If entrypoint is a function, use ``entrypoint.__qualname__``.
  149. 2. If entrypoint is a string, check its value:
  150. 2.1 if entrypoint equals to ``sys.executable`` (like "python"), use the first element from ``args``
  151. which does not start with hifen letter (for example, "-u" will be skipped).
  152. 2.2 otherwise, use ``entrypoint`` value.
  153. 3. Otherwise, return empty string.
  154. """
  155. if isinstance(entrypoint, Callable): # type: ignore[arg-type]
  156. return entrypoint.__name__ # type: ignore[union-attr]
  157. elif isinstance(entrypoint, str):
  158. if entrypoint == sys.executable:
  159. return next((arg for arg in args if arg[0] != "-"), "")
  160. else:
  161. return entrypoint
  162. else:
  163. return ""
  164. def _get_addr_and_port(
  165. rdzv_parameters: RendezvousParameters,
  166. ) -> tuple[str | None, int | None]:
  167. if rdzv_parameters.backend != "static":
  168. return (None, None)
  169. endpoint = rdzv_parameters.endpoint
  170. endpoint = endpoint.strip()
  171. if not endpoint:
  172. raise ValueError(
  173. "Endpoint is missing in endpoint. Try to add --master-addr and --master-port"
  174. )
  175. master_addr, master_port = parse_rendezvous_endpoint(endpoint, default_port=-1)
  176. if master_port == -1:
  177. raise ValueError(
  178. f"port is missing in endpoint: {endpoint}. Try to specify --master-port"
  179. )
  180. return (master_addr, master_port)
def launch_agent(
    config: LaunchConfig,
    entrypoint: Callable | str | None,
    args: list[Any],
) -> dict[int, Any]:
    """Build a ``LocalElasticAgent`` from ``config``, run it, and return the workers' results.

    Args:
        config: Launch settings. If ``config.run_id`` is empty, a random id is
            generated and written back to ``config.run_id``.
        entrypoint: Function or command for the workers; forwarded to ``WorkerSpec``.
        args: Entrypoint arguments; forwarded to ``WorkerSpec`` as a tuple.

    Returns:
        ``result.return_values`` of the agent run.

    Raises:
        ValueError: From ``_get_addr_and_port`` when the backend is ``"static"``
            and the endpoint or its port is missing.
        ChildFailedError: If the run result reports a failure.
        SignalException: Re-raised after a failed event is recorded; in this
            case the rendezvous handler is NOT shut down.
        Exception: Any other error is re-raised after a failed event is recorded.

    Side effects:
        Sets ``os.environ["TORCHELASTIC_SIGNALS_TO_HANDLE"]`` and, except on
        ``SignalException``, shuts down the rendezvous handler before returning
        or raising.
    """
    if not config.run_id:
        run_id = str(uuid.uuid4().int)
        logger.warning("config has no run_id, generated a random run_id: %s", run_id)
        config.run_id = run_id

    entrypoint_name = _get_entrypoint_name(entrypoint, args)

    logger.info(
        "Starting elastic_operator with launch configs:\n"
        " entrypoint : %(entrypoint)s\n"
        " min_nodes : %(min_nodes)s\n"
        " max_nodes : %(max_nodes)s\n"
        " nproc_per_node : %(nproc_per_node)s\n"
        " run_id : %(run_id)s\n"
        " rdzv_backend : %(rdzv_backend)s\n"
        " rdzv_endpoint : %(rdzv_endpoint)s\n"
        " rdzv_configs : %(rdzv_configs)s\n"
        " max_restarts : %(max_restarts)s\n"
        " monitor_interval : %(monitor_interval)s\n"
        " log_dir : %(log_dir)s\n"
        " metrics_cfg : %(metrics_cfg)s\n"
        " event_log_handler : %(event_log_handler)s\n"
        " numa_options : %(numa_options)s\n"
        " signals_to_handle : %(signals_to_handle)s\n"
        " duplicate_stdout_filters : %(duplicate_stdout_filters)s\n"
        " duplicate_stderr_filters : %(duplicate_stderr_filters)s\n",
        {
            "entrypoint": entrypoint_name,
            "min_nodes": config.min_nodes,
            "max_nodes": config.max_nodes,
            "nproc_per_node": config.nproc_per_node,
            "run_id": config.run_id,
            "rdzv_backend": config.rdzv_backend,
            "rdzv_endpoint": config.rdzv_endpoint,
            "rdzv_configs": config.rdzv_configs,
            "max_restarts": config.max_restarts,
            "monitor_interval": config.monitor_interval,
            "log_dir": config.logs_specs.root_log_dir,  # type: ignore[union-attr]
            "metrics_cfg": config.metrics_cfg,
            "event_log_handler": config.event_log_handler,
            "numa_options": config.numa_options,
            "signals_to_handle": config.signals_to_handle,
            "duplicate_stdout_filters": config.duplicate_stdout_filters,
            "duplicate_stderr_filters": config.duplicate_stderr_filters,
        },
    )

    # Every entry of rdzv_configs (including "timeout") is passed as a keyword argument.
    rdzv_parameters = RendezvousParameters(
        backend=config.rdzv_backend,
        endpoint=config.rdzv_endpoint,
        run_id=config.run_id,
        min_nodes=config.min_nodes,
        max_nodes=config.max_nodes,
        local_addr=config.local_addr,
        **config.rdzv_configs,
    )

    # (None, None) unless the backend is "static".
    master_addr, master_port = _get_addr_and_port(rdzv_parameters)

    # Set the signals to handle in the environment variable
    os.environ["TORCHELASTIC_SIGNALS_TO_HANDLE"] = config.signals_to_handle

    spec = WorkerSpec(
        role=config.role,
        local_world_size=config.nproc_per_node,
        entrypoint=entrypoint,
        args=tuple(args),
        rdzv_handler=rdzv_registry.get_rendezvous_handler(rdzv_parameters),
        max_restarts=config.max_restarts,
        monitor_interval=config.monitor_interval,
        master_addr=master_addr,
        master_port=master_port,
        local_addr=config.local_addr,
        event_log_handler=config.event_log_handler,
        numa_options=config.numa_options,
        duplicate_stdout_filters=config.duplicate_stdout_filters,
        duplicate_stderr_filters=config.duplicate_stderr_filters,
        virtual_local_rank=config.virtual_local_rank,
    )

    agent = LocalElasticAgent(
        spec=spec,
        logs_specs=config.logs_specs,  # type: ignore[arg-type]
        start_method=config.start_method,
        log_line_prefix_template=config.log_line_prefix_template,
    )

    # Cleared only on SignalException; read by the finally clause below.
    shutdown_rdzv = True
    try:
        metrics.initialize_metrics(metrics.MetricsConfig(config.metrics_cfg))

        result = agent.run()
        # records that agent.run() has succeeded NOT that workers have succeeded
        events.record(agent.get_event_succeeded(), config.event_log_handler)

        if result.is_failed():
            # ChildFailedError is treated specially by @record
            # if the error files for the failed children exist
            # @record will copy the first error (root cause)
            # to the error file of the launcher process.
            raise ChildFailedError(
                name=entrypoint_name,
                failures=result.failures,
            )

        return result.return_values
    except ChildFailedError:
        # Re-raised as-is: no failed event is recorded for worker failures.
        raise
    except SignalException:
        # when the agent dies with a signal do NOT shutdown the rdzv_handler
        # since this closes the rendezvous on this rdzv_id permanently and
        # prevents any additional scaling events
        shutdown_rdzv = False
        events.record(agent.get_event_failed(), config.event_log_handler)
        raise
    except Exception:
        events.record(agent.get_event_failed(), config.event_log_handler)
        raise
    finally:
        if shutdown_rdzv:
            spec.rdzv_handler.shutdown()