# constants.py (6.1 KB)
  1. from pathlib import Path
  2. from typing import Any
  3. import ray
  4. from ray._private.ray_constants import env_bool
  5. from ray.air.constants import ( # noqa: F401
  6. COPY_DIRECTORY_CHECKPOINTS_INSTEAD_OF_MOVING_ENV,
  7. EVALUATION_DATASET_KEY,
  8. MODEL_KEY,
  9. PREPROCESSOR_KEY,
  10. TRAIN_DATASET_KEY,
  11. )
  12. def _get_ray_train_session_dir() -> str:
  13. assert ray.is_initialized(), "Ray must be initialized to get the session dir."
  14. return Path(
  15. ray._private.worker._global_node.get_session_dir_path(), "artifacts"
  16. ).as_posix()
  17. DEFAULT_STORAGE_PATH = Path("~/ray_results").expanduser().as_posix()
  18. # Autofilled ray.train.report() metrics. Keys should be consistent with Tune.
  19. CHECKPOINT_DIR_NAME = "checkpoint_dir_name"
  20. TIME_TOTAL_S = "_time_total_s"
  21. WORKER_HOSTNAME = "_hostname"
  22. WORKER_NODE_IP = "_node_ip"
  23. WORKER_PID = "_pid"
  24. # Will not be reported unless ENABLE_DETAILED_AUTOFILLED_METRICS_ENV
  25. # env var is not 0
  26. DETAILED_AUTOFILLED_KEYS = {WORKER_HOSTNAME, WORKER_NODE_IP, WORKER_PID, TIME_TOTAL_S}
  27. # Default filename for JSON logger
  28. RESULT_FILE_JSON = "results.json"
  29. # The name of the subdirectory inside the trainer run_dir to store checkpoints.
  30. TRAIN_CHECKPOINT_SUBDIR = "checkpoints"
  31. # The key to use to specify the checkpoint id for Tune.
  32. # This needs to be added to the checkpoint dictionary so if the Tune trial
  33. # is restarted, the checkpoint_id can continue to increment.
  34. TUNE_CHECKPOINT_ID = "_current_checkpoint_id"
  35. # Deprecated configs can use this value to detect if the user has set it.
  36. # This has type Any to allow it to be assigned to any annotated parameter
  37. # without causing type errors.
  38. _DEPRECATED_VALUE: Any = "DEPRECATED"
  39. # ==================================================
  40. # Train V2 constants
  41. # ==================================================
  42. # Set this to 1 to enable deprecation warnings for V2 migration.
  43. ENABLE_V2_MIGRATION_WARNINGS_ENV_VAR = "RAY_TRAIN_ENABLE_V2_MIGRATION_WARNINGS"
  44. V2_MIGRATION_GUIDE_MESSAGE = (
  45. "See this issue for more context and migration options: "
  46. "https://github.com/ray-project/ray/issues/49454. "
  47. "Disable these warnings by setting the environment variable: "
  48. f"{ENABLE_V2_MIGRATION_WARNINGS_ENV_VAR}=0"
  49. )
  50. def _v2_migration_warnings_enabled() -> bool:
  51. return env_bool(ENABLE_V2_MIGRATION_WARNINGS_ENV_VAR, True)
  52. # ==================================================
  53. # Environment Variables
  54. # ==================================================
  55. ENABLE_DETAILED_AUTOFILLED_METRICS_ENV = (
  56. "TRAIN_RESULT_ENABLE_DETAILED_AUTOFILLED_METRICS"
  57. )
  58. # Integer value which if set will override the value of
  59. # Backend.share_cuda_visible_devices. 1 for True, 0 for False.
  60. ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV = "TRAIN_ENABLE_SHARE_CUDA_VISIBLE_DEVICES"
  61. # Integer value which if set will not share HIP accelerator visible devices
  62. # across workers. 1 for True (default), 0 for False.
  63. ENABLE_SHARE_HIP_VISIBLE_DEVICES_ENV = "TRAIN_ENABLE_SHARE_HIP_VISIBLE_DEVICES"
  64. # Integer value which if set will not share neuron-core accelerator visible cores
  65. # across workers. 1 for True (default), 0 for False.
  66. ENABLE_SHARE_NEURON_CORES_ACCELERATOR_ENV = (
  67. "TRAIN_ENABLE_SHARE_NEURON_CORES_ACCELERATOR"
  68. )
  69. # Integer value which if set will not share npu visible devices
  70. # across workers. 1 for True (default), 0 for False.
  71. ENABLE_SHARE_NPU_RT_VISIBLE_DEVICES_ENV = "TRAIN_ENABLE_SHARE_ASCEND_RT_VISIBLE_DEVICES"
  72. # Integer value which indicates the number of seconds to wait when creating
  73. # the worker placement group before timing out.
  74. TRAIN_PLACEMENT_GROUP_TIMEOUT_S_ENV = "TRAIN_PLACEMENT_GROUP_TIMEOUT_S"
  75. # Integer value which if set will change the placement group strategy from
  76. # PACK to SPREAD. 1 for True, 0 for False.
  77. TRAIN_ENABLE_WORKER_SPREAD_ENV = "TRAIN_ENABLE_WORKER_SPREAD"
  78. # Set this to 0 to disable changing the working directory of each Tune Trainable
  79. # or Train worker to the trial directory. Defaults to 1.
  80. RAY_CHDIR_TO_TRIAL_DIR = "RAY_CHDIR_TO_TRIAL_DIR"
  81. # Set this to 1 to count preemption errors toward `FailureConfig(max_failures)`.
  82. # Defaults to 0, which always retries on node preemption failures.
  83. RAY_TRAIN_COUNT_PREEMPTION_AS_FAILURE = "RAY_TRAIN_COUNT_PREEMPTION_AS_FAILURE"
  84. # Set this to 1 to start a StateActor and collect information Train Runs
  85. # Defaults to 0
  86. RAY_TRAIN_ENABLE_STATE_TRACKING = "RAY_TRAIN_ENABLE_STATE_TRACKING"
  87. # Set this to 1 to only store the checkpoint score attribute with the Checkpoint
  88. # in the CheckpointManager. The Result will only have the checkpoint score attribute
  89. # but files written to disk like result.json will still have all the metrics.
  90. # Defaults to 0.
  91. # TODO: this is a temporary solution to avoid CheckpointManager OOM.
  92. # See https://github.com/ray-project/ray/pull/54642#issue-3234029360 for more details.
  93. TUNE_ONLY_STORE_CHECKPOINT_SCORE_ATTRIBUTE = (
  94. "TUNE_ONLY_STORE_CHECKPOINT_SCORE_ATTRIBUTE"
  95. )
  96. # Seconds to wait for torch process group to shut down.
  97. # Shutting down a healthy torch process group, which we may want to do for reasons
  98. # like restarting a group of workers if an async checkpoint upload fails, can hang.
  99. # This is a workaround until we figure out how to avoid this hang.
  100. TORCH_PROCESS_GROUP_SHUTDOWN_TIMEOUT_S = "TORCH_PROCESS_GROUP_SHUTDOWN_TIMEOUT_S"
  101. DEFAULT_TORCH_PROCESS_GROUP_SHUTDOWN_TIMEOUT_S = 30
  102. # Seconds to wait for JAX distributed shutdown.
  103. JAX_DISTRIBUTED_SHUTDOWN_TIMEOUT_S = "JAX_DISTRIBUTED_SHUTDOWN_TIMEOUT_S"
  104. DEFAULT_JAX_DISTRIBUTED_SHUTDOWN_TIMEOUT_S = 30
  105. # NOTE: When adding a new environment variable, please track it in this list.
  106. TRAIN_ENV_VARS = {
  107. ENABLE_DETAILED_AUTOFILLED_METRICS_ENV,
  108. ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV,
  109. ENABLE_SHARE_NEURON_CORES_ACCELERATOR_ENV,
  110. TRAIN_PLACEMENT_GROUP_TIMEOUT_S_ENV,
  111. TRAIN_ENABLE_WORKER_SPREAD_ENV,
  112. RAY_CHDIR_TO_TRIAL_DIR,
  113. RAY_TRAIN_COUNT_PREEMPTION_AS_FAILURE,
  114. RAY_TRAIN_ENABLE_STATE_TRACKING,
  115. TUNE_ONLY_STORE_CHECKPOINT_SCORE_ATTRIBUTE,
  116. TORCH_PROCESS_GROUP_SHUTDOWN_TIMEOUT_S,
  117. JAX_DISTRIBUTED_SHUTDOWN_TIMEOUT_S,
  118. }
  119. # Key for AIR Checkpoint metadata in TrainingResult metadata
  120. CHECKPOINT_METADATA_KEY = "checkpoint_metadata"
  121. # Key for AIR Checkpoint world rank in TrainingResult metadata
  122. CHECKPOINT_RANK_KEY = "checkpoint_rank"