config.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. import logging
  2. import os
  3. import re
  4. import shutil
  5. import uuid
  6. from dataclasses import dataclass
  7. import ray
  8. from ray.train._internal.base_worker_group import BaseWorkerGroup
  9. from ray.train._internal.utils import get_address_and_port
  10. from ray.train.backend import Backend
  11. from ray.train.torch import TorchConfig
  12. from ray.util import PublicAPI
  13. logger = logging.getLogger(__name__)
  14. @PublicAPI(stability="alpha")
  15. @dataclass
  16. class TorchXLAConfig(TorchConfig):
  17. """
  18. Configuration for torch XLA setup.
  19. See https://pytorch.org/xla/release/1.13/index.html for more info.
  20. Currently, only "neuron_cores" accelerator (AwsNeuronXLABackend)
  21. is supported with xrt runtime.
  22. """
  23. neuron_parallel_compile: bool = False
  24. @property
  25. def backend_cls(self):
  26. return _TorchAwsNeuronXLABackend
  27. def _kill_xrt_server():
  28. import subprocess
  29. subprocess.call(["pkill", "-f", "xrt_run_server"])
  30. def _set_xla_env_vars():
  31. # https://pytorch.org/docs/1.13/elastic/run.html#environment-variables
  32. context = ray.train.get_context()
  33. os.environ["LOCAL_RANK"] = str(context.get_local_rank())
  34. os.environ["RANK"] = str(context.get_world_rank())
  35. os.environ["LOCAL_WORLD_SIZE"] = str(context.get_local_world_size())
  36. os.environ["WORLD_SIZE"] = str(context.get_world_size())
  37. os.environ["GROUP_RANK"] = str(context.get_node_rank())
  38. os.environ["GROUP_WORLD_SIZE"] = str(
  39. context.get_world_size() / context.get_local_world_size()
  40. )
  41. os.environ["ROLE_RANK"] = str(context.get_world_rank())
  42. os.environ["ROLE_WORLD_RANK"] = str(context.get_world_rank())
  43. os.environ["ROLE_WORLD_SIZE"] = str(context.get_world_size())
  44. # EFA and XLA setup
  45. # https://github.com/aws/libfabric/blob/master/prov/efa/src/rxr/rxr_init.c
  46. # https://github.com/aws-neuron/aws-neuron-samples/blob/master/torch-neuronx/training/dp_bert_hf_pretrain/run_dp_bert_large_hf_pretrain_bf16_s128.sh # noqa
  47. os.environ["FI_PROVIDER"] = "efa"
  48. os.environ["FI_EFA_USE_DEVICE_RDMA"] = "1"
  49. os.environ["FI_EFA_FORK_SAFE"] = "1"
  50. os.environ["XLA_TRANSFER_SEED_ASYNC"] = "1"
  51. os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1"
  52. def _setup_xla_torch_process_group():
  53. try:
  54. import torch.distributed as dist
  55. import torch_xla.core.xla_model as xm # noqa F401
  56. import torch_xla.distributed.xla_backend # noqa F401
  57. dist.init_process_group("xla")
  58. except ImportError:
  59. raise ImportError("torch_xla must be installed to use torch_xla backend.")
  60. # The following env vars enable Neuron graph extraction for parallel compilation
  61. # Note: model outputs are invalid and should be ignored while these env vars are set
  62. def _set_neuron_parallel_compile_env_vars():
  63. os.environ["NEURON_PARALLEL_COMPILE"] = "1"
  64. os.environ["NEURON_EXTRACT_GRAPHS_ONLY"] = "1"
  65. os.environ["NEURON_FALL_BACK_TO_NULL_NEFF"] = "1"
  66. # Compile previously extracted Neuron graphs
  67. def _neuron_compile_extracted_graphs():
  68. try:
  69. from libneuronxla.neuron_cc_cache import CacheUrl
  70. from libneuronxla.neuron_parallel_compile import parallel_compile
  71. except ImportError:
  72. raise ImportError(
  73. "libneuronxla must be installed to use Neuron parallel compilation."
  74. )
  75. # Only 1 worker per node should run parallel_compile()
  76. if os.environ.get("LOCAL_RANK") == "0":
  77. logger.info("Compiling extracted graphs on local rank0 worker")
  78. parallel_compile_workdir = (
  79. f"/tmp/{os.environ.get('USER','no-user')}/parallel_compile_workdir/"
  80. )
  81. if os.path.exists(parallel_compile_workdir):
  82. shutil.rmtree(parallel_compile_workdir)
  83. os.makedirs(parallel_compile_workdir, exist_ok=True)
  84. # Users can set the cache directory using --cache_dir in NEURON_CC_FLAGS or by
  85. # using NEURON_COMPILE_CACHE_URL. --cache_dir takes precedence.
  86. explicit_cache_dir = None
  87. if neuron_cc_flags := os.environ.get("NEURON_CC_FLAGS"):
  88. if s := re.search(r"--cache_dir[= ](\S+)", neuron_cc_flags):
  89. explicit_cache_dir = s.group(1)
  90. parallel_compile(
  91. parallel_compile_workdir,
  92. CacheUrl.get_cache_url(explicit_cache_dir),
  93. )
  94. class _TorchAwsNeuronXLABackend(Backend):
  95. unique_run_id: str = str(uuid.uuid4())
  96. def on_start(self, worker_group: BaseWorkerGroup, backend_config: TorchXLAConfig):
  97. """Logic ran right before training is started."""
  98. # On previous worker failure, we don't run graceful shutdown on workers.
  99. # This would leak any running xrt server.
  100. worker_group.execute(_kill_xrt_server)
  101. # Get master address and port from the first worker.
  102. master_addr, master_port = worker_group.execute_single(0, get_address_and_port)
  103. def set_env_vars(addr, port):
  104. os.environ["MASTER_ADDR"] = addr
  105. os.environ["MASTER_PORT"] = str(port)
  106. # To trigger the xrt server
  107. os.environ["TORCHELASTIC_RUN_ID"] = self.unique_run_id
  108. # Set the env vars on all workers.
  109. worker_group.execute(set_env_vars, addr=master_addr, port=master_port)
  110. # Set up env vars for neuron parallel compilation graph extraction
  111. if backend_config.neuron_parallel_compile:
  112. logger.info("Extracting graphs for Neuron parallel compilation")
  113. worker_group.execute(_set_neuron_parallel_compile_env_vars)
  114. def on_training_start(
  115. self, worker_group: BaseWorkerGroup, backend_config: TorchXLAConfig
  116. ):
  117. """
  118. Configure the environment variables for the worker group.
  119. And initialize the xla distributed process group.
  120. TODO: Current setup only supports homogenous cluster with
  121. neuron_cores accelerator and xrt runtime.
  122. """
  123. worker_group.execute(_set_xla_env_vars)
  124. worker_group.execute(_setup_xla_torch_process_group)
  125. def on_shutdown(
  126. self, worker_group: BaseWorkerGroup, backend_config: TorchXLAConfig
  127. ):
  128. """
  129. Logic ran right after training is finished.
  130. This is a sanity cleanup to kill xrt server, and to optionally
  131. run neuron parallel graph compilation
  132. """
  133. worker_group.execute(_kill_xrt_server)
  134. # Compile the extracted graphs. This must run at end of training.
  135. if backend_config.neuron_parallel_compile:
  136. worker_group.execute(_neuron_compile_extracted_graphs)