run.py 34 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994
  1. #!/usr/bin/env python3
  2. # mypy: allow-untyped-defs
  3. # Copyright (c) Facebook, Inc. and its affiliates.
  4. # All rights reserved.
  5. #
  6. # This source code is licensed under the BSD-style license found in the
  7. # LICENSE file in the root directory of this source tree.
  8. """
  9. Module ``torch.distributed.run``.
  10. ``torch.distributed.run`` is a module that spawns up multiple distributed
  11. training processes on each of the training nodes.
  12. ``torchrun`` is a python
  13. `console script <https://packaging.python.org/en/latest/specifications/entry-points/#use-for-scripts>`_
  14. to the main module
  15. `torch.distributed.run <https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py>`_
  16. declared in the ``entry_points`` configuration in
  17. `setup.py <https://github.com/pytorch/pytorch/blob/master/setup.py>`_.
  18. It is equivalent to invoking ``python -m torch.distributed.run``.
  19. ``torchrun`` can be used for single-node distributed training, in which one or
  20. more processes per node will be spawned. It can be used for either
  21. CPU training or GPU training. If it is used for GPU training,
  22. each distributed process will be operating on a single GPU. This can achieve
  23. well-improved single-node training performance. ``torchrun`` can also be used in
  24. multi-node distributed training, by spawning up multiple processes on each node
  25. for well-improved multi-node distributed training performance as well.
  26. This will especially be beneficial for systems with multiple Infiniband
  27. interfaces that have direct-GPU support, since all of them can be utilized for
  28. aggregated communication bandwidth.
  29. In both cases of single-node distributed training or multi-node distributed
  30. training, ``torchrun`` will launch the given number of processes per node
  31. (``--nproc-per-node``). If used for GPU training, this number needs to be less
  32. or equal to the number of GPUs on the current system (``nproc_per_node``),
  33. and each process will be operating on a single GPU from *GPU 0 to
  34. GPU (nproc_per_node - 1)*.
  35. .. versionchanged:: 2.0.0
  36. ``torchrun`` will pass the ``--local-rank=<rank>`` argument to your script.
  37. From PyTorch 2.0.0 onwards, the dashed ``--local-rank`` is preferred over the
  38. previously used underscored ``--local_rank``.
  39. For backward compatibility, it may be necessary for users to handle both
  40. cases in their argument parsing code. This means including both ``"--local-rank"``
  41. and ``"--local_rank"`` in the argument parser. If only ``"--local_rank"`` is
  42. provided, ``torchrun`` will trigger an error: "error: unrecognized arguments:
  43. --local-rank=<rank>". For training code that only supports PyTorch 2.0.0+,
  44. including ``"--local-rank"`` should be sufficient.
  45. ::
  46. >>> # xdoctest: +SKIP
  47. >>> import argparse
  48. >>> parser = argparse.ArgumentParser()
  49. >>> parser.add_argument("--local-rank", "--local_rank", type=int)
  50. >>> args = parser.parse_args()
  51. Usage
  52. -----
  53. Single-node multi-worker
  54. ++++++++++++++++++++++++
  55. ::
  56. torchrun
  57. --standalone
  58. --nnodes=1
  59. --nproc-per-node=$NUM_TRAINERS
  60. YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)
  61. .. note:: ``--nproc-per-node`` may be
  62. ``"gpu"`` (spawn one process per GPU),
  63. ``"cpu"`` (spawn one process per CPU),
  64. ``"xpu"`` (spawn one process per XPU),
  65. ``"auto"`` (equivalent to ``"gpu"`` if CUDA is available,
  66. else equivalent to ``"xpu"`` if XPU is available,
  67. else equivalent to ``"cpu"``),
  68. or an integer specifying the number of processes.
  69. See `torch.distributed.run.determine_local_world_size
  70. <https://github.com/pytorch/pytorch/blob/0a94bb432ed75cc2d950d81b2921363218a7e459/torch/distributed/run.py#L673-L716>`_
  71. for more details.
  72. Stacked single-node multi-worker
  73. ++++++++++++++++++++++++++++++++
  74. To run multiple instances (separate jobs) of single-node, multi-worker on the
  75. same host, we need to make sure that each instance (job) is
  76. setup on different ports to avoid port conflicts (or worse, two jobs being merged
  77. as a single job). To do this you have to run with ``--rdzv-backend=c10d``
  78. and specify a different port by setting ``--rdzv-endpoint=localhost:$PORT_k``.
  79. For ``--nnodes=1``, it's often convenient to let ``torchrun`` pick a free random
  80. port automatically instead of manually assigning different ports for each run.
  81. ::
  82. torchrun
  83. --rdzv-backend=c10d
  84. --rdzv-endpoint=localhost:0
  85. --nnodes=1
  86. --nproc-per-node=$NUM_TRAINERS
  87. YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)
  88. Fault tolerant (fixed sized number of workers, no elasticity, tolerates 3 failures)
  89. +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
  90. ::
  91. torchrun
  92. --nnodes=$NUM_NODES
  93. --nproc-per-node=$NUM_TRAINERS
  94. --max-restarts=3
  95. --rdzv-id=$JOB_ID
  96. --rdzv-backend=c10d
  97. --rdzv-endpoint=$HOST_NODE_ADDR
  98. YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)
  99. ``HOST_NODE_ADDR``, in form <host>[:<port>] (e.g. node1.example.com:29400), specifies the node and
  100. the port on which the C10d rendezvous backend should be instantiated and hosted. It can be any
  101. node in your training cluster, but ideally you should pick a node that has a high bandwidth.
  102. .. note::
  103. If no port number is specified ``HOST_NODE_ADDR`` defaults to 29400.
  104. Elastic (``min=1``, ``max=4``, tolerates up to 3 membership changes or failures)
  105. ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
  106. ::
  107. torchrun
  108. --nnodes=1:4
  109. --nproc-per-node=$NUM_TRAINERS
  110. --max-restarts=3
  111. --rdzv-id=$JOB_ID
  112. --rdzv-backend=c10d
  113. --rdzv-endpoint=$HOST_NODE_ADDR
  114. YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)
  115. ``HOST_NODE_ADDR``, in form <host>[:<port>] (e.g. node1.example.com:29400), specifies the node and
  116. the port on which the C10d rendezvous backend should be instantiated and hosted. It can be any
  117. node in your training cluster, but ideally you should pick a node that has a high bandwidth.
  118. .. note::
  119. If no port number is specified ``HOST_NODE_ADDR`` defaults to 29400.
  120. Note on rendezvous backend
  121. --------------------------
  122. For multi-node training you need to specify:
  123. 1. ``--rdzv-id``: A unique job id (shared by all nodes participating in the job)
  124. 2. ``--rdzv-backend``: An implementation of
  125. :py:class:`torch.distributed.elastic.rendezvous.RendezvousHandler`
  126. 3. ``--rdzv-endpoint``: The endpoint where the rendezvous backend is running; usually in form
  127. ``host:port``.
  128. Currently ``c10d`` (recommended), ``etcd-v2``, and ``etcd`` (legacy) rendezvous backends are
  129. supported out of the box. To use ``etcd-v2`` or ``etcd``, setup an etcd server with the ``v2`` api
  130. enabled (e.g. ``--enable-v2``).
  131. .. warning::
  132. ``etcd-v2`` and ``etcd`` rendezvous use etcd API v2. You MUST enable the v2 API on the etcd
  133. server. Our tests use etcd v3.4.3.
  134. .. warning::
  135. For etcd-based rendezvous we recommend using ``etcd-v2`` over ``etcd`` which is functionally
  136. equivalent, but uses a revised implementation. ``etcd`` is in maintenance mode and will be
  137. removed in a future version.
  138. Definitions
  139. -----------
  140. 1. ``Node`` - A physical instance or a container; maps to the unit that the job manager works with.
  141. 2. ``Worker`` - A worker in the context of distributed training.
  142. 3. ``WorkerGroup`` - The set of workers that execute the same function (e.g. trainers).
  143. 4. ``LocalWorkerGroup`` - A subset of the workers in the worker group running on the same node.
  144. 5. ``RANK`` - The rank of the worker within a worker group.
  145. 6. ``WORLD_SIZE`` - The total number of workers in a worker group.
  146. 7. ``LOCAL_RANK`` - The rank of the worker within a local worker group.
  147. 8. ``LOCAL_WORLD_SIZE`` - The size of the local worker group.
  148. 9. ``rdzv_id`` - A user-defined id that uniquely identifies the worker group for a job. This id is
  149. used by each node to join as a member of a particular worker group.
  150. 10. ``rdzv_backend`` - The backend of the rendezvous (e.g. ``c10d``). This is typically a strongly
  151. consistent key-value store.
  152. 11. ``rdzv_endpoint`` - The rendezvous backend endpoint; usually in form ``<host>:<port>``.
  153. A ``Node`` runs ``LOCAL_WORLD_SIZE`` workers which comprise a ``LocalWorkerGroup``. The union of
  154. all ``LocalWorkerGroups`` in the nodes in the job comprise the ``WorkerGroup``.
  155. Environment Variables
  156. ---------------------
  157. The following environment variables are made available to you in your script:
  158. 1. ``LOCAL_RANK`` - The local rank.
  159. 2. ``RANK`` - The global rank.
  160. 3. ``GROUP_RANK`` - The rank of the worker group. A number between 0 and ``max_nnodes``. When
  161. running a single worker group per node, this is the rank of the node.
  162. 4. ``ROLE_RANK`` - The rank of the worker across all the workers that have the same role. The role
  163. of the worker is specified in the ``WorkerSpec``.
  164. 5. ``LOCAL_WORLD_SIZE`` - The local world size (e.g. number of workers running locally); equal to
  165. ``--nproc-per-node`` specified on ``torchrun``.
  166. 6. ``WORLD_SIZE`` - The world size (total number of workers in the job).
  167. 7. ``ROLE_WORLD_SIZE`` - The total number of workers that was launched with the same role specified
  168. in ``WorkerSpec``.
  169. 8. ``MASTER_ADDR`` - The FQDN of the host that is running worker with rank 0; used to initialize
  170. the Torch Distributed backend.
  171. 9. ``MASTER_PORT`` - The port on the ``MASTER_ADDR`` that can be used to host the C10d TCP store.
  172. 10. ``TORCHELASTIC_RESTART_COUNT`` - The number of worker group restarts so far.
  173. 11. ``TORCHELASTIC_MAX_RESTARTS`` - The configured maximum number of restarts.
  174. 12. ``TORCHELASTIC_RUN_ID`` - Equal to the rendezvous ``run_id`` (e.g. unique job id).
  175. 13. ``PYTHON_EXEC`` - System executable override. If provided, the python user script will
  176. use the value of ``PYTHON_EXEC`` as executable. The `sys.executable` is used by default.
  177. Deployment
  178. ----------
  179. 1. (Not needed for the C10d backend) Start the rendezvous backend server and get the endpoint (to be
  180. passed as ``--rdzv-endpoint`` to ``torchrun``)
  181. 2. Single-node multi-worker: Start ``torchrun`` on the host to start the agent process which
  182. creates and monitors a local worker group.
  183. 3. Multi-node multi-worker: Start ``torchrun`` with the same arguments on all the nodes
  184. participating in training.
  185. When using a job/cluster manager, the entry point command to the multi-node job should be ``torchrun``.
  186. Failure Modes
  187. -------------
  188. 1. Worker failure: For a training job with ``n`` workers, if ``k<=n`` workers fail all workers
  189. are stopped and restarted up to ``max_restarts``.
  190. 2. Agent failure: An agent failure results in a local worker group failure. It is up to the job
  191. manager to fail the entire job (gang semantics) or attempt to replace the node. Both behaviors
  192. are supported by the agent.
  193. 3. Node failure: Same as agent failure.
  194. Membership Changes
  195. ------------------
  196. 1. Node departure (scale-down): The agent is notified of the departure, all existing workers are
  197. stopped, a new ``WorkerGroup`` is formed, and all workers are started with a new ``RANK`` and
  198. ``WORLD_SIZE``.
  199. 2. Node arrival (scale-up): The new node is admitted to the job, all existing workers are stopped,
  200. a new ``WorkerGroup`` is formed, and all workers are started with a new ``RANK`` and
  201. ``WORLD_SIZE``.
  202. NUMA Binding
  203. ------------
  204. On multi-GPU systems with NUMA (Non-Uniform Memory Access) architecture, you can improve
  205. performance by binding worker processes to CPUs near their assigned GPUs. Use the
  206. ``--numa-binding`` flag:
  207. ::
  208. torchrun --numa-binding=node --nproc-per-node=8 train.py
  209. See :ref:`numa-api` for more details.
  210. Important Notices
  211. -----------------
  212. 1. This utility and multi-process distributed (single-node or
  213. multi-node) GPU training currently only achieves the best performance using
  214. the NCCL distributed backend. Thus NCCL backend is the recommended backend to
  215. use for GPU training.
  216. 2. The environment variables necessary to initialize a Torch process group are provided to you by
  217. this module, no need for you to pass ``RANK`` manually. To initialize a process group in your
  218. training script, simply run:
  219. ::
  220. >>> # xdoctest: +SKIP("stub")
  221. >>> import torch.distributed as dist
  222. >>> dist.init_process_group(backend="gloo|nccl")
  223. 3. In your training program, you can either use regular distributed functions
  224. or use :func:`torch.nn.parallel.DistributedDataParallel` module. If your
  225. training program uses GPUs for training and you would like to use
  226. :func:`torch.nn.parallel.DistributedDataParallel` module,
  227. here is how to configure it.
  228. ::
  229. local_rank = int(os.environ["LOCAL_RANK"])
  230. model = torch.nn.parallel.DistributedDataParallel(
  231. model, device_ids=[local_rank], output_device=local_rank
  232. )
  233. Please ensure that ``device_ids`` argument is set to be the only GPU device id
  234. that your code will be operating on. This is generally the local rank of the
  235. process. In other words, the ``device_ids`` needs to be ``[int(os.environ["LOCAL_RANK"])]``,
  236. and ``output_device`` needs to be ``int(os.environ["LOCAL_RANK"])`` in order to use this
  237. utility.
  238. 4. On failures or membership changes ALL surviving workers are killed immediately. Make sure to
  239. checkpoint your progress. The frequency of checkpoints should depend on your job's tolerance
  240. for lost work.
  241. 5. This module only supports homogeneous ``LOCAL_WORLD_SIZE``. That is, it is assumed that all
  242. nodes run the same number of local workers (per role).
  243. 6. ``RANK`` is NOT stable. Between restarts, the local workers on a node can be assigned a
  244. different range of ranks than before. NEVER hard code any assumptions about the stable-ness of
  245. ranks or some correlation between ``RANK`` and ``LOCAL_RANK``.
  246. 7. When using elasticity (``min_size!=max_size``) DO NOT hard code assumptions about
  247. ``WORLD_SIZE`` as the world size can change as nodes are allowed to leave and join.
  248. 8. It is recommended for your script to have the following structure:
  249. ::
  250. def main():
  251. load_checkpoint(checkpoint_path)
  252. initialize()
  253. train()
  254. def train():
  255. for batch in iter(dataset):
  256. train_step(batch)
  257. if should_checkpoint:
  258. save_checkpoint(checkpoint_path)
  259. 9. (Recommended) On worker errors, this tool will summarize the details of the error
  260. (e.g. time, rank, host, pid, traceback, etc). On each node, the first error (by timestamp)
  261. is heuristically reported as the "Root Cause" error. To get tracebacks as part of this
  262. error summary print out, you must decorate your main entrypoint function in your
  263. training script as shown in the example below. If not decorated, then the summary
  264. will not include the traceback of the exception and will only contain the exitcode.
  265. For details on torchelastic error handling see: https://pytorch.org/docs/stable/elastic/errors.html
  266. ::
  267. from torch.distributed.elastic.multiprocessing.errors import record
  268. @record
  269. def main():
  270. # do train
  271. pass
  272. if __name__ == "__main__":
  273. main()
  274. """ # noqa: E501
  275. import os
  276. import sys
  277. import uuid
  278. from argparse import ArgumentParser, REMAINDER
  279. from collections.abc import Callable
  280. from importlib import metadata
  281. import torch
  282. from torch.distributed.argparse_util import check_env, env
  283. from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs, LogsSpecs, Std
  284. from torch.distributed.elastic.multiprocessing.errors import record
  285. from torch.distributed.elastic.rendezvous.utils import _parse_rendezvous_config
  286. from torch.distributed.elastic.utils import macros
  287. from torch.distributed.elastic.utils.logging import get_logger
  288. from torch.distributed.launcher.api import elastic_launch, LaunchConfig
  289. from torch.numa.binding import (
  290. AffinityMode as _AffinityMode, # Signify as private with _
  291. NumaOptions as _NumaOptions,
  292. )
  293. from torch.utils.backend_registration import _get_custom_mod_func
# Module-level logger, created once via the elastic logging helper and shared
# by the functions in this module.
logger = get_logger(__name__)
  295. def get_args_parser() -> ArgumentParser:
  296. """Parse the command line options."""
  297. parser = ArgumentParser(description="Torch Distributed Elastic Training Launcher")
  298. def comma_separated_list(value):
  299. placeholder = "<COMMA_PLACEHOLDER>"
  300. value = value.replace(",,", placeholder)
  301. items = value.split(",")
  302. items = [item.replace(placeholder, ",") for item in items]
  303. return items
  304. #
  305. # Worker/node size related arguments.
  306. #
  307. parser.add_argument(
  308. "--nnodes",
  309. action=env,
  310. type=str,
  311. default="1:1",
  312. help="Number of nodes, or the range of nodes in form <minimum_nodes>:<maximum_nodes>.",
  313. )
  314. parser.add_argument(
  315. "--nproc-per-node",
  316. "--nproc_per_node",
  317. action=env,
  318. type=str,
  319. default="1",
  320. help="Number of workers per node; supported values: [auto, cpu, gpu, xpu, int].",
  321. )
  322. #
  323. # Rendezvous related arguments
  324. #
  325. parser.add_argument(
  326. "--rdzv-backend",
  327. "--rdzv_backend",
  328. action=env,
  329. type=str,
  330. default="static",
  331. help="Rendezvous backend.",
  332. )
  333. parser.add_argument(
  334. "--rdzv-endpoint",
  335. "--rdzv_endpoint",
  336. action=env,
  337. type=str,
  338. default="",
  339. help="Rendezvous backend endpoint; usually in form <host>:<port>.",
  340. )
  341. parser.add_argument(
  342. "--rdzv-id",
  343. "--rdzv_id",
  344. action=env,
  345. type=str,
  346. default="none",
  347. help="User-defined group id.",
  348. )
  349. parser.add_argument(
  350. "--rdzv-conf",
  351. "--rdzv_conf",
  352. action=env,
  353. type=str,
  354. default="",
  355. help="Additional rendezvous configuration (<key1>=<value1>,<key2>=<value2>,...).",
  356. )
  357. parser.add_argument(
  358. "--standalone",
  359. action=check_env,
  360. help="Start a local standalone rendezvous backend that is represented by a C10d TCP store "
  361. "on a free port. Useful when launching single-node, multi-worker job. If specified "
  362. "--rdzv-backend, --rdzv-endpoint, --rdzv-id are auto-assigned and any explicitly set values "
  363. "are ignored.",
  364. )
  365. #
  366. # User-code launch related arguments.
  367. #
  368. parser.add_argument(
  369. "--max-restarts",
  370. "--max_restarts",
  371. action=env,
  372. type=int,
  373. default=0,
  374. help="Maximum number of worker group restarts before failing.",
  375. )
  376. parser.add_argument(
  377. "--monitor-interval",
  378. "--monitor_interval",
  379. action=env,
  380. type=float,
  381. default=0.1,
  382. help="Interval, in seconds, to monitor the state of workers.",
  383. )
  384. parser.add_argument(
  385. "--start-method",
  386. "--start_method",
  387. action=env,
  388. type=str,
  389. default="spawn",
  390. choices=["spawn", "fork", "forkserver"],
  391. help="Multiprocessing start method to use when creating workers.",
  392. )
  393. parser.add_argument(
  394. "--event-log-handler",
  395. "--event_log_handler",
  396. action=env,
  397. type=str,
  398. default="null",
  399. help="name of a registered event logging handler (see: https://docs.pytorch.org/docs/stable/elastic/events.html)",
  400. )
  401. parser.add_argument(
  402. "--role",
  403. action=env,
  404. type=str,
  405. default="default",
  406. help="User-defined role for the workers.",
  407. )
  408. parser.add_argument(
  409. "-m",
  410. "--module",
  411. action=check_env,
  412. help="Change each process to interpret the launch script as a Python module, executing "
  413. "with the same behavior as 'python -m'.",
  414. )
  415. parser.add_argument(
  416. "--no-python",
  417. "--no_python",
  418. action=check_env,
  419. help="Skip prepending the training script with 'python' - just execute it directly. Useful "
  420. "when the script is not a Python script.",
  421. )
  422. parser.add_argument(
  423. "--run-path",
  424. "--run_path",
  425. action=check_env,
  426. help="Run the training script with runpy.run_path in the same interpreter."
  427. " Script must be provided as an abs path (e.g. /abs/path/script.py)."
  428. " Takes precedence over --no-python.",
  429. )
  430. parser.add_argument(
  431. "--log-dir",
  432. "--log_dir",
  433. action=env,
  434. type=str,
  435. default=None,
  436. help="Base directory to use for log files (e.g. /var/log/torch/elastic). The same "
  437. "directory is reused for multiple runs (a unique job-level sub-directory is created with "
  438. "rdzv_id as the prefix).",
  439. )
  440. parser.add_argument(
  441. "-r",
  442. "--redirects",
  443. action=env,
  444. type=str,
  445. default="0",
  446. help="Redirect std streams into a log file in the log directory (e.g. [-r 3] redirects "
  447. "both stdout+stderr for all workers, [-r 0:1,1:2] redirects stdout for local rank 0 and "
  448. "stderr for local rank 1).",
  449. )
  450. parser.add_argument(
  451. "-t",
  452. "--tee",
  453. action=env,
  454. type=str,
  455. default="0",
  456. help="Tee std streams into a log file and also to console (see --redirects for format).",
  457. )
  458. parser.add_argument(
  459. "--local-ranks-filter",
  460. "--local_ranks_filter",
  461. action=env,
  462. type=str,
  463. default="",
  464. help="Only show logs from specified ranks in console (e.g. [--local_ranks_filter=0,1,2] will "
  465. "only show logs from rank 0, 1 and 2). This will only apply to stdout and stderr, not to"
  466. "log files saved via --redirect or --tee",
  467. )
  468. parser.add_argument(
  469. "--duplicate-stdout-filters",
  470. "--duplicate_stdout_filters",
  471. action=env,
  472. type=comma_separated_list,
  473. default=[],
  474. help="Duplicates logs streamed to stdout to another specified file with a list of filters (e.g. "
  475. "[--duplicate_stdout_filters 'apple,orange'] will duplicate log lines matching 'apple' "
  476. "OR 'orange'. An empty filters list won't duplicate any lines. Use double comma to escape a comma) ",
  477. )
  478. parser.add_argument(
  479. "--duplicate-stderr-filters",
  480. "--duplicate_stderr_filters",
  481. action=env,
  482. type=comma_separated_list,
  483. default=[],
  484. help="Duplicates logs streamed to stderr to another specified file with a list of filters (e.g. "
  485. "[--duplicate_stdout_filters 'apple,orange'] will duplicate log lines matching 'apple' "
  486. "OR 'orange'. An empty filters list won't duplicate any lines. Use double comma to escape a comma) ",
  487. )
  488. #
  489. # Backwards compatible parameters with caffe2.distributed.launch.
  490. #
  491. parser.add_argument(
  492. "--node-rank",
  493. "--node_rank",
  494. type=int,
  495. action=env,
  496. default=0,
  497. help="Rank of the node for multi-node distributed training.",
  498. )
  499. parser.add_argument(
  500. "--master-addr",
  501. "--master_addr",
  502. default="127.0.0.1",
  503. type=str,
  504. action=env,
  505. help="Address of the master node (rank 0) that only used for static rendezvous. It should "
  506. "be either the IP address or the hostname of rank 0. For single node multi-proc training "
  507. "the --master-addr can simply be 127.0.0.1; IPv6 should have the pattern "
  508. "`[0:0:0:0:0:0:0:1]`.",
  509. )
  510. parser.add_argument(
  511. "--master-port",
  512. "--master_port",
  513. default=29500,
  514. type=int,
  515. action=env,
  516. help="Port on the master node (rank 0) to be used for communication during distributed "
  517. "training. It is only used for static rendezvous.",
  518. )
  519. parser.add_argument(
  520. "--local-addr",
  521. "--local_addr",
  522. default=None,
  523. type=str,
  524. action=env,
  525. help="Address of the local node. If specified, will use the given address for connection. "
  526. "Else, will look up the local node address instead. Else, it will be default to local "
  527. "machine's FQDN.",
  528. )
  529. parser.add_argument(
  530. "--logs-specs",
  531. "--logs_specs",
  532. default=None,
  533. type=str,
  534. help="torchrun.logs_specs group entrypoint name, value must be type of LogsSpecs. "
  535. "Can be used to override custom logging behavior.",
  536. )
  537. parser.add_argument(
  538. "--numa-binding",
  539. "--numa_binding",
  540. type=str,
  541. choices=[mode.value for mode in _AffinityMode],
  542. default=None,
  543. help="Bind worker processes to CPUs near their assigned GPUs for better performance. "
  544. "See torch/numa/binding.py for available modes and details.",
  545. )
  546. parser.add_argument(
  547. "--signals-to-handle",
  548. "--signals_to_handle",
  549. action=env,
  550. type=str,
  551. default="SIGTERM,SIGINT,SIGHUP,SIGQUIT",
  552. help="Comma-separated list of signals to handle and forward to subprocesses. "
  553. "Default: SIGTERM,SIGINT,SIGHUP,SIGQUIT. "
  554. "Common additional signals: SIGUSR1,SIGUSR2 (used in SLURM environments).",
  555. )
  556. parser.add_argument(
  557. "--virtual-local-rank",
  558. "--virtual_local_rank",
  559. action=check_env,
  560. help="Enable virtual local rank mode for workers. When enabled, LOCAL_RANK is set to 0 "
  561. "for all workers and CUDA_VISIBLE_DEVICES is adjusted so each worker accesses its "
  562. "assigned GPU at device index 0.",
  563. )
  564. #
  565. # Positional arguments.
  566. #
  567. parser.add_argument(
  568. "training_script",
  569. type=str,
  570. help="Full path to the (single GPU) training program/script to be launched in parallel, "
  571. "followed by all the arguments for the training script.",
  572. )
  573. # Rest from the training program.
  574. parser.add_argument("training_script_args", nargs=REMAINDER)
  575. return parser
  576. def parse_args(args):
  577. parser = get_args_parser()
  578. return parser.parse_args(args)
  579. def parse_min_max_nnodes(nnodes: str):
  580. arr = nnodes.split(":")
  581. if len(arr) == 1:
  582. min_nodes = max_nodes = int(arr[0])
  583. elif len(arr) == 2:
  584. min_nodes = int(arr[0])
  585. max_nodes = int(arr[1])
  586. else:
  587. raise RuntimeError(f'nnodes={nnodes} is not in "MIN:MAX" format') # noqa: E231
  588. return min_nodes, max_nodes
  589. def determine_local_world_size(nproc_per_node: str):
  590. try:
  591. logger.info("Using nproc_per_node=%s.", nproc_per_node)
  592. return int(nproc_per_node)
  593. except ValueError as e:
  594. if nproc_per_node == "cpu":
  595. num_proc = os.cpu_count()
  596. device_type = "cpu"
  597. elif nproc_per_node == "gpu":
  598. if not torch.cuda.is_available():
  599. raise ValueError("Cuda is not available.") from e
  600. device_type = "gpu"
  601. num_proc = torch.cuda.device_count()
  602. elif nproc_per_node == "xpu":
  603. if not torch.xpu.is_available():
  604. raise ValueError("Xpu is not available.") from e
  605. device_type = "xpu"
  606. num_proc = torch.xpu.device_count()
  607. elif nproc_per_node == torch._C._get_privateuse1_backend_name():
  608. if not _get_custom_mod_func("is_available")():
  609. raise ValueError(f"{nproc_per_node} is not available.") from e
  610. device_type = nproc_per_node
  611. num_proc = _get_custom_mod_func("device_count")()
  612. elif nproc_per_node == "auto":
  613. if torch.accelerator.is_available():
  614. num_proc = torch.accelerator.device_count()
  615. device_type = torch.accelerator.current_accelerator().type # type: ignore[union-attr]
  616. else:
  617. num_proc = os.cpu_count()
  618. device_type = "cpu"
  619. else:
  620. raise ValueError(
  621. f"Unsupported nproc_per_node value: {nproc_per_node}"
  622. ) from e
  623. logger.info(
  624. "Using nproc_per_node=%s, setting nproc_per_node to %s since the instance has %s %s",
  625. nproc_per_node,
  626. num_proc,
  627. num_proc,
  628. device_type,
  629. )
  630. return num_proc
  631. def get_rdzv_endpoint(args):
  632. if args.rdzv_backend == "static" and not args.rdzv_endpoint:
  633. return f"{args.master_addr}:{args.master_port}" # noqa: E231
  634. return args.rdzv_endpoint
  635. def get_use_env(args) -> bool:
  636. """
  637. Retrieve ``use_env`` from the args.
  638. ``use_env`` is a legacy argument, if ``use_env`` is False, the
  639. ``--node-rank`` argument will be transferred to all worker processes.
  640. ``use_env`` is only used by the ``torch.distributed.launch`` and will
  641. be deprecated in future releases.
  642. """
  643. if not hasattr(args, "use_env"):
  644. return True
  645. return args.use_env
  646. def _get_logs_specs_class(logs_specs_name: str | None) -> type[LogsSpecs]:
  647. """
  648. Attempts to load `torchrun.logs_spec` entrypoint with key of `logs_specs_name` param.
  649. Provides plugin mechanism to provide custom implementation of LogsSpecs.
  650. Returns `DefaultLogsSpecs` when logs_spec_name is None.
  651. Raises ValueError when entrypoint for `logs_spec_name` can't be found in entrypoints.
  652. """
  653. logs_specs_cls = None
  654. if logs_specs_name is not None:
  655. eps = metadata.entry_points()
  656. group = eps.select(group="torchrun.logs_specs")
  657. if group.select(name=logs_specs_name):
  658. # pyrefly: ignore [bad-index]
  659. logs_specs_cls = group[logs_specs_name].load()
  660. if logs_specs_cls is None:
  661. raise ValueError(
  662. f"Could not find entrypoint under 'torchrun.logs_specs[{logs_specs_name}]' key"
  663. )
  664. logger.info(
  665. "Using logs_spec '%s' mapped to %s", logs_specs_name, str(logs_specs_cls)
  666. )
  667. else:
  668. logs_specs_cls = DefaultLogsSpecs
  669. return logs_specs_cls
  670. def config_from_args(args) -> tuple[LaunchConfig, Callable | str, list[str]]:
  671. # If ``args`` not passed, defaults to ``sys.argv[:1]``
  672. min_nodes, max_nodes = parse_min_max_nnodes(args.nnodes)
  673. if not (0 < min_nodes <= max_nodes):
  674. raise AssertionError(
  675. f"min_nodes must be > 0 and <= max_nodes, got min_nodes={min_nodes}, max_nodes={max_nodes}"
  676. )
  677. if args.max_restarts < 0:
  678. raise AssertionError("max_restarts must be >= 0")
  679. if (
  680. hasattr(args, "master_addr")
  681. and args.rdzv_backend != "static"
  682. and not args.rdzv_endpoint
  683. ):
  684. logger.warning(
  685. "master_addr is only used for static rdzv_backend and when rdzv_endpoint "
  686. "is not specified."
  687. )
  688. nproc_per_node = determine_local_world_size(args.nproc_per_node)
  689. if "OMP_NUM_THREADS" not in os.environ and nproc_per_node > 1:
  690. omp_num_threads = 1
  691. logger.warning(
  692. "\n*****************************************\n"
  693. "Setting OMP_NUM_THREADS environment variable for each process to be "
  694. "%s in default, to avoid your system being overloaded, "
  695. "please further tune the variable for optimal performance in "
  696. "your application as needed. \n"
  697. "*****************************************",
  698. omp_num_threads,
  699. )
  700. # This env variable will be passed down to the subprocesses
  701. os.environ["OMP_NUM_THREADS"] = str(omp_num_threads)
  702. log_line_prefix_template = os.getenv("TORCHELASTIC_LOG_LINE_PREFIX_TEMPLATE")
  703. rdzv_configs = _parse_rendezvous_config(args.rdzv_conf)
  704. if args.rdzv_backend == "static":
  705. rdzv_configs["rank"] = args.node_rank
  706. rdzv_endpoint = get_rdzv_endpoint(args)
  707. ranks: set[int] | None = None
  708. if args.local_ranks_filter:
  709. try:
  710. ranks = set(map(int, args.local_ranks_filter.split(",")))
  711. if not ranks:
  712. raise AssertionError("ranks set cannot be empty")
  713. except Exception as e:
  714. raise ValueError(
  715. "--local_ranks_filter must be a comma-separated list of integers e.g. --local_ranks_filter=0,1,2"
  716. ) from e
  717. logs_specs_cls: type[LogsSpecs] = _get_logs_specs_class(args.logs_specs)
  718. logs_specs = logs_specs_cls(
  719. log_dir=args.log_dir,
  720. redirects=Std.from_str(args.redirects),
  721. tee=Std.from_str(args.tee),
  722. local_ranks_filter=ranks,
  723. )
  724. numa_options = (
  725. None
  726. if args.numa_binding is None
  727. else _NumaOptions(affinity_mode=_AffinityMode(args.numa_binding))
  728. )
  729. config = LaunchConfig(
  730. min_nodes=min_nodes,
  731. max_nodes=max_nodes,
  732. nproc_per_node=nproc_per_node,
  733. run_id=args.rdzv_id,
  734. role=args.role,
  735. rdzv_endpoint=rdzv_endpoint,
  736. rdzv_backend=args.rdzv_backend,
  737. rdzv_configs=rdzv_configs,
  738. max_restarts=args.max_restarts,
  739. monitor_interval=args.monitor_interval,
  740. start_method=args.start_method,
  741. log_line_prefix_template=log_line_prefix_template,
  742. local_addr=args.local_addr,
  743. logs_specs=logs_specs,
  744. event_log_handler=args.event_log_handler,
  745. numa_options=numa_options,
  746. signals_to_handle=args.signals_to_handle,
  747. duplicate_stdout_filters=args.duplicate_stdout_filters,
  748. duplicate_stderr_filters=args.duplicate_stderr_filters,
  749. virtual_local_rank=args.virtual_local_rank,
  750. )
  751. with_python = not args.no_python
  752. cmd: Callable | str
  753. cmd_args = []
  754. use_env = get_use_env(args)
  755. if args.run_path:
  756. cmd = run_script_path
  757. cmd_args.append(args.training_script)
  758. else:
  759. if with_python:
  760. cmd = os.getenv("PYTHON_EXEC", sys.executable)
  761. cmd_args.append("-u")
  762. if args.module:
  763. cmd_args.append("-m")
  764. cmd_args.append(args.training_script)
  765. else:
  766. if args.module:
  767. raise ValueError(
  768. "Don't use both the '--no-python' flag"
  769. " and the '--module' flag at the same time."
  770. )
  771. cmd = args.training_script
  772. if not use_env:
  773. cmd_args.append(f"--local-rank={macros.local_rank}")
  774. cmd_args.extend(args.training_script_args)
  775. return config, cmd, cmd_args
  776. def run_script_path(training_script: str, *training_script_args: str):
  777. """
  778. Run the provided `training_script` from within this interpreter.
  779. Usage: `script_as_function("/abs/path/to/script.py", "--arg1", "val1")`
  780. """
  781. import runpy
  782. import sys
  783. sys.argv = [training_script] + [*training_script_args]
  784. runpy.run_path(sys.argv[0], run_name="__main__")
  785. def run(args):
  786. torch.multiprocessing._set_thread_name("pt_elastic")
  787. if args.standalone:
  788. args.rdzv_backend = "c10d"
  789. args.rdzv_endpoint = "localhost:0"
  790. args.rdzv_id = str(uuid.uuid4())
  791. logger.info(
  792. "\n**************************************\n"
  793. "Rendezvous info:\n"
  794. "--rdzv-backend=%s "
  795. "--rdzv-endpoint=%s "
  796. "--rdzv-id=%s\n"
  797. "**************************************\n",
  798. args.rdzv_backend,
  799. args.rdzv_endpoint,
  800. args.rdzv_id,
  801. )
  802. config, cmd, cmd_args = config_from_args(args)
  803. elastic_launch(
  804. config=config,
  805. entrypoint=cmd,
  806. )(*cmd_args)
  807. @record
  808. def main(args=None):
  809. args = parse_args(args)
  810. run(args)
  811. if __name__ == "__main__":
  812. main()