module.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290
  1. import abc
  2. import asyncio
  3. import inspect
  4. import logging
  5. import multiprocessing
  6. import multiprocessing.connection
  7. import os
  8. import sys
  9. from dataclasses import dataclass
  10. import aiohttp
  11. import ray
  12. from ray import ray_constants
  13. from ray._private import logging_utils
  14. from ray._private.gcs_utils import GcsChannel
  15. from ray._private.ray_logging import setup_component_logger
  16. from ray._raylet import GcsClient
  17. from ray.dashboard.subprocesses.utils import (
  18. get_named_pipe_path,
  19. get_socket_path,
  20. module_logging_filename,
  21. )
  22. logger = logging.getLogger(__name__)
  23. @dataclass
  24. class SubprocessModuleConfig:
  25. """
  26. Configuration for a SubprocessModule.
  27. Pickleable.
  28. """
  29. cluster_id_hex: str
  30. gcs_address: str
  31. session_name: str
  32. temp_dir: str
  33. session_dir: str
  34. # Logger configs. Will be set up in subprocess entrypoint `run_module`.
  35. logging_level: str
  36. logging_format: str
  37. log_dir: str
  38. # Name of the "base" log file. Its stem is appended with the Module.__name__.
  39. # e.g. when logging_filename = "dashboard.log", and Module is JobHead,
  40. # we will set up logger with name "dashboard_JobHead.log". This name will again be
  41. # appended with .1 and .2 for rotation.
  42. logging_filename: str
  43. logging_rotate_bytes: int
  44. logging_rotate_backup_count: int
  45. # The directory where the socket file will be created.
  46. socket_dir: str
  47. class SubprocessModule(abc.ABC):
  48. """
  49. A Dashboard Head Module that runs in a subprocess as a standalone aiohttp server.
  50. """
  51. def __init__(
  52. self,
  53. config: SubprocessModuleConfig,
  54. ):
  55. """
  56. Initialize current module when DashboardHead loading modules.
  57. :param dashboard_head: The DashboardHead instance.
  58. """
  59. self._config = config
  60. self._parent_process = multiprocessing.parent_process()
  61. # Lazy init
  62. self._gcs_client = None
  63. self._aiogrpc_gcs_channel = None
  64. self._parent_process_death_detection_task = None
  65. self._http_session = None
  66. async def _detect_parent_process_death(self):
  67. """
  68. Detect parent process liveness. Only returns when parent process is dead.
  69. """
  70. while True:
  71. if not self._parent_process.is_alive():
  72. logger.warning(
  73. f"Parent process {self._parent_process.pid} died. Exiting..."
  74. )
  75. return
  76. await asyncio.sleep(1)
  77. @staticmethod
  78. def is_minimal_module():
  79. """
  80. Currently all SubprocessModule classes should be non-minimal.
  81. We require this because SubprocessModuleHandle tracks aiohttp requests and
  82. responses. To ease this, we can define another SubprocessModuleMinimalHandle
  83. that doesn't track requests and responses, but still provides Queue interface
  84. and health check.
  85. TODO(ryw): If needed, create SubprocessModuleMinimalHandle.
  86. """
  87. return False
  88. async def run(self):
  89. """
  90. Start running the module.
  91. This method should be called first before the module starts receiving requests.
  92. """
  93. app = aiohttp.web.Application(
  94. client_max_size=ray_constants.DASHBOARD_CLIENT_MAX_SIZE,
  95. )
  96. routes: list[aiohttp.web.RouteDef] = [
  97. aiohttp.web.get("/api/healthz", self._internal_module_health_check)
  98. ]
  99. handlers = inspect.getmembers(
  100. self,
  101. lambda x: (
  102. inspect.ismethod(x)
  103. and hasattr(x, "__route_method__")
  104. and hasattr(x, "__route_path__")
  105. ),
  106. )
  107. for _, handler in handlers:
  108. routes.append(
  109. aiohttp.web.route(
  110. handler.__route_method__,
  111. handler.__route_path__,
  112. handler,
  113. )
  114. )
  115. app.add_routes(routes)
  116. runner = aiohttp.web.AppRunner(app, access_log=None)
  117. await runner.setup()
  118. module_name = self.__class__.__name__
  119. if sys.platform == "win32":
  120. named_pipe_path = get_named_pipe_path(
  121. module_name, self._config.session_name
  122. )
  123. site = aiohttp.web.NamedPipeSite(runner, named_pipe_path)
  124. logger.info(f"Started aiohttp server over {named_pipe_path}.")
  125. else:
  126. socket_path = get_socket_path(self._config.socket_dir, module_name)
  127. site = aiohttp.web.UnixSite(runner, socket_path)
  128. logger.info(f"Started aiohttp server over {socket_path}.")
  129. await site.start()
  130. @property
  131. def gcs_client(self):
  132. if self._gcs_client is None:
  133. if not ray.experimental.internal_kv._internal_kv_initialized():
  134. gcs_client = GcsClient(
  135. address=self._config.gcs_address,
  136. cluster_id=self._config.cluster_id_hex,
  137. )
  138. ray.experimental.internal_kv._initialize_internal_kv(gcs_client)
  139. self._gcs_client = ray.experimental.internal_kv.internal_kv_get_gcs_client()
  140. return self._gcs_client
  141. @property
  142. def aiogrpc_gcs_channel(self):
  143. if self._aiogrpc_gcs_channel is None:
  144. gcs_channel = GcsChannel(gcs_address=self._config.gcs_address, aio=True)
  145. gcs_channel.connect()
  146. self._aiogrpc_gcs_channel = gcs_channel.channel()
  147. return self._aiogrpc_gcs_channel
  148. @property
  149. def session_name(self):
  150. """
  151. Return the Ray session name. It's not related to the aiohttp session.
  152. """
  153. return self._config.session_name
  154. @property
  155. def temp_dir(self):
  156. return self._config.temp_dir
  157. @property
  158. def session_dir(self):
  159. return self._config.session_dir
  160. @property
  161. def log_dir(self):
  162. return self._config.log_dir
  163. @property
  164. def http_session(self):
  165. if self._http_session is None:
  166. self._http_session = aiohttp.ClientSession()
  167. return self._http_session
  168. @property
  169. def gcs_address(self):
  170. return self._config.gcs_address
  171. async def _internal_module_health_check(self, request):
  172. return aiohttp.web.Response(
  173. text="success",
  174. content_type="application/text",
  175. )
  176. async def run_module_inner(
  177. cls: type[SubprocessModule],
  178. config: SubprocessModuleConfig,
  179. incarnation: int,
  180. child_conn: multiprocessing.connection.Connection,
  181. ):
  182. module_name = cls.__name__
  183. logger.info(
  184. f"Starting module {module_name} with incarnation {incarnation} and config {config}"
  185. )
  186. try:
  187. module = cls(config)
  188. module._parent_process_death_detection_task = asyncio.create_task(
  189. module._detect_parent_process_death()
  190. )
  191. module._parent_process_death_detection_task.add_done_callback(
  192. lambda _: sys.exit()
  193. )
  194. await module.run()
  195. child_conn.send(None)
  196. child_conn.close()
  197. logger.info(f"Module {module_name} initialized, receiving messages...")
  198. except Exception as e:
  199. logger.exception(f"Error creating module {module_name}")
  200. raise e
  201. def run_module(
  202. cls: type[SubprocessModule],
  203. config: SubprocessModuleConfig,
  204. incarnation: int,
  205. child_conn: multiprocessing.connection.Connection,
  206. ):
  207. """
  208. Entrypoint for a subprocess module.
  209. """
  210. module_name = cls.__name__
  211. current_proctitle = ray._raylet.getproctitle()
  212. ray._raylet.setproctitle(
  213. f"ray-dashboard-{module_name}-{incarnation} ({current_proctitle})"
  214. )
  215. logging_filename = module_logging_filename(module_name, config.logging_filename)
  216. setup_component_logger(
  217. logging_level=config.logging_level,
  218. logging_format=config.logging_format,
  219. log_dir=config.log_dir,
  220. filename=logging_filename,
  221. max_bytes=config.logging_rotate_bytes,
  222. backup_count=config.logging_rotate_backup_count,
  223. )
  224. if config.logging_filename:
  225. stdout_filename = module_logging_filename(
  226. module_name, config.logging_filename, extension=".out"
  227. )
  228. stderr_filename = module_logging_filename(
  229. module_name, config.logging_filename, extension=".err"
  230. )
  231. logging_utils.redirect_stdout_stderr_if_needed(
  232. os.path.join(config.log_dir, stdout_filename),
  233. os.path.join(config.log_dir, stderr_filename),
  234. config.logging_rotate_bytes,
  235. config.logging_rotate_backup_count,
  236. )
  237. loop = asyncio.new_event_loop()
  238. task = loop.create_task(
  239. run_module_inner(
  240. cls,
  241. config,
  242. incarnation,
  243. child_conn,
  244. )
  245. )
  246. # TODO: do graceful shutdown.
  247. # 1. define a stop token.
  248. # 2. join the loop to wait for all pending tasks to finish, up until a timeout.
  249. # 3. close the loop and exit.
  250. def sigterm_handler(signum, frame):
  251. logger.warning(f"Exiting with signal {signum} immediately...")
  252. sys.exit(signum)
  253. ray._private.utils.set_sigterm_handler(sigterm_handler)
  254. loop.run_until_complete(task)
  255. loop.run_forever()