start_ray_node.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211
  1. import fcntl
  2. import logging
  3. import os.path
  4. import shutil
  5. import signal
  6. import socket
  7. import subprocess
  8. import sys
  9. import threading
  10. import time
  11. from ray._private.ray_process_reaper import SIGTERM_GRACE_PERIOD_SECONDS
  12. from ray.util.spark.cluster_init import (
  13. RAY_ON_SPARK_COLLECT_LOG_TO_PATH,
  14. RAY_ON_SPARK_START_RAY_PARENT_PID,
  15. )
# Ray on Spark implementation does not directly invoke the `ray start ...` script to
# create the ray node subprocess. Instead, it creates a subprocess to run this
# `ray.util.spark.start_ray_node` module, and this module invokes the `ray start ...`
# script to start the ray node. The purpose of the `start_ray_node` module is to set
# up an exit handler for cleaning the ray temp directory when the ray node exits.
# When the spark driver python process dies, or the spark python worker dies,
# the `check_parent_alive` daemon thread started by `start_ray_node` detects that
# the parent process died and then triggers the cleanup work.
  24. _logger = logging.getLogger(__name__)
  25. if __name__ == "__main__":
  26. arg_list = sys.argv[1:]
  27. collect_log_to_path = os.environ[RAY_ON_SPARK_COLLECT_LOG_TO_PATH]
  28. temp_dir_arg_prefix = "--temp-dir="
  29. temp_dir = None
  30. for arg in arg_list:
  31. if arg.startswith(temp_dir_arg_prefix):
  32. temp_dir = arg[len(temp_dir_arg_prefix) :]
  33. if temp_dir is not None:
  34. temp_dir = os.path.normpath(temp_dir)
  35. else:
  36. # This case is for global mode Ray on spark cluster
  37. from ray.util.spark.cluster_init import _get_default_ray_tmp_dir
  38. temp_dir = _get_default_ray_tmp_dir()
  39. # Multiple Ray nodes might be launched in the same machine,
  40. # so set `exist_ok` to True
  41. os.makedirs(temp_dir, exist_ok=True)
  42. ray_cli_cmd = "ray"
  43. lock_file = temp_dir + ".lock"
  44. lock_fd = os.open(lock_file, os.O_RDWR | os.O_CREAT | os.O_TRUNC)
  45. # Mutilple ray nodes might start on the same machine, and they are using the
  46. # same temp directory, adding a shared lock representing current ray node is
  47. # using the temp directory.
  48. fcntl.flock(lock_fd, fcntl.LOCK_SH)
  49. process = subprocess.Popen(
  50. # 'ray start ...' command uses python that is set by
  51. # Shebang #! ..., the Shebang line is hardcoded in ray script,
  52. # it can't be changed to other python executable path.
  53. # to enforce using current python executable,
  54. # turn the subprocess command to
  55. # '`sys.executable` `which ray` start ...'
  56. [sys.executable, shutil.which(ray_cli_cmd), "start", *arg_list],
  57. text=True,
  58. )
  59. exit_handler_executed = False
  60. sigterm_handler_executed = False
  61. ON_EXIT_HANDLER_WAIT_TIME = 3
  62. def on_exit_handler():
  63. global exit_handler_executed
  64. if exit_handler_executed:
  65. # wait for exit_handler execution completed in other threads.
  66. time.sleep(ON_EXIT_HANDLER_WAIT_TIME)
  67. return
  68. exit_handler_executed = True
  69. try:
  70. # Wait for a while to ensure the children processes of the ray node all
  71. # exited.
  72. time.sleep(SIGTERM_GRACE_PERIOD_SECONDS + 0.5)
  73. if process.poll() is None:
  74. # "ray start ..." command process is still alive. Force to kill it.
  75. process.kill()
  76. # Release the shared lock, representing current ray node does not use the
  77. # temp dir.
  78. fcntl.flock(lock_fd, fcntl.LOCK_UN)
  79. try:
  80. # acquiring exclusive lock to ensure copy logs and removing dir safely.
  81. fcntl.flock(lock_fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
  82. lock_acquired = True
  83. except BlockingIOError:
  84. # The file has active shared lock or exclusive lock, representing there
  85. # are other ray nodes running, or other node running cleanup temp-dir
  86. # routine. skip cleaning temp-dir, and skip copy logs to destination
  87. # directory as well.
  88. lock_acquired = False
  89. if lock_acquired:
  90. # This is the final terminated ray node on current spark worker,
  91. # start copy logs (including all local ray nodes logs) to destination.
  92. if collect_log_to_path:
  93. try:
  94. log_dir_prefix = os.path.basename(temp_dir)
  95. if log_dir_prefix == "ray":
  96. # global mode cluster case, append a timestamp to it to
  97. # avoid name conflict with last Ray global cluster log dir.
  98. log_dir_prefix = (
  99. log_dir_prefix + f"-global-{int(time.time())}"
  100. )
  101. base_dir = os.path.join(
  102. collect_log_to_path, log_dir_prefix + "-logs"
  103. )
  104. # Note: multiple Ray node launcher process might
  105. # execute this line code, so we set exist_ok=True here.
  106. os.makedirs(base_dir, exist_ok=True)
  107. copy_log_dest_path = os.path.join(
  108. base_dir,
  109. socket.gethostname(),
  110. )
  111. ray_session_dir = os.readlink(
  112. os.path.join(temp_dir, "session_latest")
  113. )
  114. shutil.copytree(
  115. os.path.join(ray_session_dir, "logs"),
  116. copy_log_dest_path,
  117. )
  118. except Exception as e:
  119. _logger.warning(
  120. "Collect logs to destination directory failed, "
  121. f"error: {repr(e)}."
  122. )
  123. # Start cleaning the temp-dir,
  124. shutil.rmtree(temp_dir, ignore_errors=True)
  125. except Exception:
  126. # swallow any exception.
  127. pass
  128. finally:
  129. fcntl.flock(lock_fd, fcntl.LOCK_UN)
  130. os.close(lock_fd)
  131. def check_parent_alive() -> None:
  132. orig_parent_pid = int(os.environ[RAY_ON_SPARK_START_RAY_PARENT_PID])
  133. while True:
  134. time.sleep(0.5)
  135. if os.getppid() != orig_parent_pid:
  136. # Note raising SIGTERM signal in a background thread
  137. # doesn't work
  138. sigterm_handler()
  139. break
  140. threading.Thread(target=check_parent_alive, daemon=True).start()
  141. try:
  142. def sighup_handler(*args):
  143. pass
  144. # When spark application is terminated, this process will receive
  145. # SIGHUP (comes from pyspark application termination).
  146. # Ignore the SIGHUP signal, because in this case,
  147. # `check_parent_alive` will capture parent process died event
  148. # and execute killing node and cleanup routine
  149. # but if we enable default SIGHUP handler, it will kill
  150. # the process immediately and it causes `check_parent_alive`
  151. # have no time to exeucte cleanup routine.
  152. signal.signal(signal.SIGHUP, sighup_handler)
  153. def sigterm_handler(*args):
  154. global sigterm_handler_executed
  155. if not sigterm_handler_executed:
  156. sigterm_handler_executed = True
  157. process.terminate()
  158. on_exit_handler()
  159. else:
  160. # wait for exit_handler execution completed in other threads.
  161. time.sleep(ON_EXIT_HANDLER_WAIT_TIME)
  162. # Sigterm exit code is 143.
  163. os._exit(143)
  164. signal.signal(signal.SIGTERM, sigterm_handler)
  165. while True:
  166. try:
  167. ret_code = process.wait()
  168. break
  169. except KeyboardInterrupt:
  170. # Jupyter notebook interrupt button triggers SIGINT signal and
  171. # `start_ray_node` (subprocess) will receive SIGINT signal and it
  172. # causes KeyboardInterrupt exception being raised.
  173. pass
  174. on_exit_handler()
  175. sys.exit(ret_code)
  176. except Exception:
  177. on_exit_handler()
  178. raise