| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290 |
- import abc
- import asyncio
- import inspect
- import logging
- import multiprocessing
- import multiprocessing.connection
- import os
- import sys
- from dataclasses import dataclass
- import aiohttp
- import ray
- from ray import ray_constants
- from ray._private import logging_utils
- from ray._private.gcs_utils import GcsChannel
- from ray._private.ray_logging import setup_component_logger
- from ray._raylet import GcsClient
- from ray.dashboard.subprocesses.utils import (
- get_named_pipe_path,
- get_socket_path,
- module_logging_filename,
- )
- logger = logging.getLogger(__name__)
- @dataclass
- class SubprocessModuleConfig:
- """
- Configuration for a SubprocessModule.
- Pickleable.
- """
- cluster_id_hex: str
- gcs_address: str
- session_name: str
- temp_dir: str
- session_dir: str
- # Logger configs. Will be set up in subprocess entrypoint `run_module`.
- logging_level: str
- logging_format: str
- log_dir: str
- # Name of the "base" log file. Its stem is appended with the Module.__name__.
- # e.g. when logging_filename = "dashboard.log", and Module is JobHead,
- # we will set up logger with name "dashboard_JobHead.log". This name will again be
- # appended with .1 and .2 for rotation.
- logging_filename: str
- logging_rotate_bytes: int
- logging_rotate_backup_count: int
- # The directory where the socket file will be created.
- socket_dir: str
- class SubprocessModule(abc.ABC):
- """
- A Dashboard Head Module that runs in a subprocess as a standalone aiohttp server.
- """
- def __init__(
- self,
- config: SubprocessModuleConfig,
- ):
- """
- Initialize current module when DashboardHead loading modules.
- :param dashboard_head: The DashboardHead instance.
- """
- self._config = config
- self._parent_process = multiprocessing.parent_process()
- # Lazy init
- self._gcs_client = None
- self._aiogrpc_gcs_channel = None
- self._parent_process_death_detection_task = None
- self._http_session = None
- async def _detect_parent_process_death(self):
- """
- Detect parent process liveness. Only returns when parent process is dead.
- """
- while True:
- if not self._parent_process.is_alive():
- logger.warning(
- f"Parent process {self._parent_process.pid} died. Exiting..."
- )
- return
- await asyncio.sleep(1)
- @staticmethod
- def is_minimal_module():
- """
- Currently all SubprocessModule classes should be non-minimal.
- We require this because SubprocessModuleHandle tracks aiohttp requests and
- responses. To ease this, we can define another SubprocessModuleMinimalHandle
- that doesn't track requests and responses, but still provides Queue interface
- and health check.
- TODO(ryw): If needed, create SubprocessModuleMinimalHandle.
- """
- return False
- async def run(self):
- """
- Start running the module.
- This method should be called first before the module starts receiving requests.
- """
- app = aiohttp.web.Application(
- client_max_size=ray_constants.DASHBOARD_CLIENT_MAX_SIZE,
- )
- routes: list[aiohttp.web.RouteDef] = [
- aiohttp.web.get("/api/healthz", self._internal_module_health_check)
- ]
- handlers = inspect.getmembers(
- self,
- lambda x: (
- inspect.ismethod(x)
- and hasattr(x, "__route_method__")
- and hasattr(x, "__route_path__")
- ),
- )
- for _, handler in handlers:
- routes.append(
- aiohttp.web.route(
- handler.__route_method__,
- handler.__route_path__,
- handler,
- )
- )
- app.add_routes(routes)
- runner = aiohttp.web.AppRunner(app, access_log=None)
- await runner.setup()
- module_name = self.__class__.__name__
- if sys.platform == "win32":
- named_pipe_path = get_named_pipe_path(
- module_name, self._config.session_name
- )
- site = aiohttp.web.NamedPipeSite(runner, named_pipe_path)
- logger.info(f"Started aiohttp server over {named_pipe_path}.")
- else:
- socket_path = get_socket_path(self._config.socket_dir, module_name)
- site = aiohttp.web.UnixSite(runner, socket_path)
- logger.info(f"Started aiohttp server over {socket_path}.")
- await site.start()
- @property
- def gcs_client(self):
- if self._gcs_client is None:
- if not ray.experimental.internal_kv._internal_kv_initialized():
- gcs_client = GcsClient(
- address=self._config.gcs_address,
- cluster_id=self._config.cluster_id_hex,
- )
- ray.experimental.internal_kv._initialize_internal_kv(gcs_client)
- self._gcs_client = ray.experimental.internal_kv.internal_kv_get_gcs_client()
- return self._gcs_client
- @property
- def aiogrpc_gcs_channel(self):
- if self._aiogrpc_gcs_channel is None:
- gcs_channel = GcsChannel(gcs_address=self._config.gcs_address, aio=True)
- gcs_channel.connect()
- self._aiogrpc_gcs_channel = gcs_channel.channel()
- return self._aiogrpc_gcs_channel
- @property
- def session_name(self):
- """
- Return the Ray session name. It's not related to the aiohttp session.
- """
- return self._config.session_name
- @property
- def temp_dir(self):
- return self._config.temp_dir
- @property
- def session_dir(self):
- return self._config.session_dir
- @property
- def log_dir(self):
- return self._config.log_dir
- @property
- def http_session(self):
- if self._http_session is None:
- self._http_session = aiohttp.ClientSession()
- return self._http_session
- @property
- def gcs_address(self):
- return self._config.gcs_address
- async def _internal_module_health_check(self, request):
- return aiohttp.web.Response(
- text="success",
- content_type="application/text",
- )
- async def run_module_inner(
- cls: type[SubprocessModule],
- config: SubprocessModuleConfig,
- incarnation: int,
- child_conn: multiprocessing.connection.Connection,
- ):
- module_name = cls.__name__
- logger.info(
- f"Starting module {module_name} with incarnation {incarnation} and config {config}"
- )
- try:
- module = cls(config)
- module._parent_process_death_detection_task = asyncio.create_task(
- module._detect_parent_process_death()
- )
- module._parent_process_death_detection_task.add_done_callback(
- lambda _: sys.exit()
- )
- await module.run()
- child_conn.send(None)
- child_conn.close()
- logger.info(f"Module {module_name} initialized, receiving messages...")
- except Exception as e:
- logger.exception(f"Error creating module {module_name}")
- raise e
- def run_module(
- cls: type[SubprocessModule],
- config: SubprocessModuleConfig,
- incarnation: int,
- child_conn: multiprocessing.connection.Connection,
- ):
- """
- Entrypoint for a subprocess module.
- """
- module_name = cls.__name__
- current_proctitle = ray._raylet.getproctitle()
- ray._raylet.setproctitle(
- f"ray-dashboard-{module_name}-{incarnation} ({current_proctitle})"
- )
- logging_filename = module_logging_filename(module_name, config.logging_filename)
- setup_component_logger(
- logging_level=config.logging_level,
- logging_format=config.logging_format,
- log_dir=config.log_dir,
- filename=logging_filename,
- max_bytes=config.logging_rotate_bytes,
- backup_count=config.logging_rotate_backup_count,
- )
- if config.logging_filename:
- stdout_filename = module_logging_filename(
- module_name, config.logging_filename, extension=".out"
- )
- stderr_filename = module_logging_filename(
- module_name, config.logging_filename, extension=".err"
- )
- logging_utils.redirect_stdout_stderr_if_needed(
- os.path.join(config.log_dir, stdout_filename),
- os.path.join(config.log_dir, stderr_filename),
- config.logging_rotate_bytes,
- config.logging_rotate_backup_count,
- )
- loop = asyncio.new_event_loop()
- task = loop.create_task(
- run_module_inner(
- cls,
- config,
- incarnation,
- child_conn,
- )
- )
- # TODO: do graceful shutdown.
- # 1. define a stop token.
- # 2. join the loop to wait for all pending tasks to finish, up until a timeout.
- # 3. close the loop and exit.
- def sigterm_handler(signum, frame):
- logger.warning(f"Exiting with signal {signum} immediately...")
- sys.exit(signum)
- ray._private.utils.set_sigterm_handler(sigterm_handler)
- loop.run_until_complete(task)
- loop.run_forever()
|