config.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189
  1. import logging
  2. import os
  3. from dataclasses import dataclass
  4. from typing import Optional
  5. import ray
  6. from ray._private import ray_constants
  7. from ray.train._internal.utils import get_address_and_port
  8. from ray.train._internal.worker_group import WorkerGroup
  9. from ray.train.backend import Backend, BackendConfig
  10. from ray.train.constants import (
  11. DEFAULT_JAX_DISTRIBUTED_SHUTDOWN_TIMEOUT_S,
  12. JAX_DISTRIBUTED_SHUTDOWN_TIMEOUT_S,
  13. )
  14. from ray.util import PublicAPI
  15. from ray.util.tpu import get_tpu_coordinator_env_vars, get_tpu_worker_resources
  16. logger = logging.getLogger(__name__)
  17. @PublicAPI(stability="alpha")
  18. @dataclass
  19. class JaxConfig(BackendConfig):
  20. use_tpu: bool = False
  21. use_gpu: bool = False
  22. @property
  23. def backend_cls(self):
  24. return _JaxBackend
  25. def _setup_jax_distributed_environment(
  26. master_addr_with_port: str,
  27. num_workers: int,
  28. index: int,
  29. use_tpu: bool,
  30. use_gpu: bool,
  31. resources_per_worker: dict,
  32. jax_env_vars: Optional[dict] = None,
  33. ):
  34. """Set up distributed Jax training information.
  35. This function should be called on each worker. It sets JAX environment
  36. variables and initializes JAX distributed training.
  37. Args:
  38. master_addr_with_port: The master address with port for coordination.
  39. num_workers: Total number of workers.
  40. index: Index of this worker.
  41. use_tpu: Whether to configure for TPU. If True and JAX_PLATFORMS is not
  42. already set, it will be set to "tpu".
  43. use_gpu: Whether to configure for GPU. If True and JAX_PLATFORMS is not
  44. already set, it will be set to "cuda".
  45. resources_per_worker: The resources per worker.
  46. jax_env_vars: The JAX coordinator env vars to inject for multi-slice. These
  47. values do not override existing values if specified.
  48. """
  49. # Get JAX_PLATFORMS from environment if already set
  50. jax_platforms = os.environ.get("JAX_PLATFORMS", "").lower()
  51. if not jax_platforms and use_tpu:
  52. os.environ["JAX_PLATFORMS"] = "tpu"
  53. jax_platforms = "tpu"
  54. if jax_env_vars:
  55. for k, v in jax_env_vars.items():
  56. # Respect configured JAX env vars if set.
  57. if k not in os.environ:
  58. os.environ[k] = v
  59. if not jax_platforms and use_gpu:
  60. os.environ["JAX_PLATFORMS"] = "cuda"
  61. jax_platforms = "cuda"
  62. if "cuda" in jax_platforms.split(","):
  63. num_gpus_per_worker = resources_per_worker.get("GPU", 0)
  64. os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
  65. str(i) for i in range(num_gpus_per_worker)
  66. )
  67. import jax
  68. if "tpu" in jax_platforms.split(","):
  69. jax.distributed.initialize(master_addr_with_port, num_workers, index)
  70. logger.info("Initialized JAX distributed on TPU.")
  71. if "cuda" in jax_platforms.split(","):
  72. if num_gpus_per_worker > 0:
  73. local_device_ids = list(range(num_gpus_per_worker))
  74. else:
  75. local_device_ids = 0
  76. jax.distributed.initialize(
  77. master_addr_with_port, num_workers, index, local_device_ids
  78. )
  79. logger.info("Initialized JAX distributed on CUDA.")
  80. def _shutdown_jax_distributed():
  81. """Shutdown JAX distributed environment.
  82. This function should be called on each worker during cleanup.
  83. If JAX distributed was not initialized, this is a no-op.
  84. """
  85. try:
  86. import jax
  87. jax.distributed.shutdown()
  88. except Exception as e:
  89. logger.warning(f"Error during JAX distributed shutdown: {e}")
  90. class _JaxBackend(Backend):
  91. def on_start(self, worker_group: WorkerGroup, backend_config: JaxConfig):
  92. if not backend_config.use_tpu and not backend_config.use_gpu:
  93. return
  94. master_addr, master_port = worker_group.execute_single(0, get_address_and_port)
  95. master_addr_with_port = f"{master_addr}:{master_port}"
  96. if backend_config.use_tpu and hasattr(worker_group, "get_worker_group_context"):
  97. num_slices = worker_group.get_worker_group_context().num_slices
  98. else:
  99. num_slices = 1
  100. # Calculate the number of workers per slice for multi-slice env setup.
  101. if backend_config.use_tpu and num_slices > 1:
  102. # Handle the case where a user requests less workers than the total
  103. # capacity of the TPU slice.
  104. scaling_config = worker_group._train_run_context.scaling_config
  105. workers_per_slice, _ = get_tpu_worker_resources(
  106. topology=scaling_config.topology,
  107. accelerator_type=scaling_config.accelerator_type,
  108. resources_per_unit=scaling_config.resources_per_worker,
  109. num_slices=1,
  110. )
  111. else:
  112. # Assume even distribution based on the requested number of workers.
  113. workers_per_slice = max(1, len(worker_group) // num_slices)
  114. # Set up JAX distributed environment on all workers
  115. num_workers_total = len(worker_group)
  116. setup_futures = []
  117. for i in range(num_workers_total):
  118. env_vars = {}
  119. if num_slices > 1:
  120. slice_id = min(i // workers_per_slice, num_slices - 1)
  121. env_vars = get_tpu_coordinator_env_vars(
  122. coordinator_address=master_addr,
  123. num_slices=num_slices,
  124. slice_id=slice_id,
  125. )
  126. setup_futures.append(
  127. worker_group.execute_single_async(
  128. i,
  129. _setup_jax_distributed_environment,
  130. master_addr_with_port=master_addr_with_port,
  131. num_workers=len(worker_group),
  132. index=i,
  133. use_tpu=backend_config.use_tpu,
  134. use_gpu=backend_config.use_gpu,
  135. resources_per_worker=worker_group.get_resources_per_worker(),
  136. jax_env_vars=env_vars,
  137. )
  138. )
  139. ray.get(setup_futures)
  140. def on_shutdown(self, worker_group: WorkerGroup, backend_config: JaxConfig):
  141. """Cleanup JAX distributed resources when shutting down worker group."""
  142. if not backend_config.use_tpu and not backend_config.use_gpu:
  143. return
  144. # Shutdown JAX distributed on all workers
  145. shutdown_futures = worker_group.execute_async(_shutdown_jax_distributed)
  146. timeout_s = ray_constants.env_integer(
  147. JAX_DISTRIBUTED_SHUTDOWN_TIMEOUT_S,
  148. DEFAULT_JAX_DISTRIBUTED_SHUTDOWN_TIMEOUT_S,
  149. )
  150. try:
  151. ray.get(shutdown_futures, timeout=timeout_s)
  152. logger.debug("JAX distributed shutdown completed")
  153. except ray.exceptions.GetTimeoutError:
  154. logger.warning(
  155. f"JAX distributed shutdown timed out after {timeout_s} seconds. "
  156. "This may indicate workers are hung or unresponsive."
  157. )
  158. except Exception as e:
  159. logger.warning(f"Error during JAX distributed shutdown: {e}")