| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521 |
- import argparse
- import asyncio
- import json
- import logging
- import os
- import signal
- import sys
- import ray
- import ray._private.ray_constants as ray_constants
- import ray.dashboard.consts as dashboard_consts
- import ray.dashboard.utils as dashboard_utils
- from ray._common.network_utils import build_address, is_localhost
- from ray._common.utils import get_or_create_event_loop
- from ray._private import logging_utils
- from ray._private.process_watcher import create_check_raylet_task
- from ray._private.ray_constants import AGENT_GRPC_MAX_MESSAGE_LENGTH
- from ray._private.ray_logging import setup_component_logger
- from ray._raylet import (
- DASHBOARD_AGENT_LISTEN_PORT_NAME,
- METRICS_AGENT_PORT_NAME,
- METRICS_EXPORT_PORT_NAME,
- GcsClient,
- persist_port,
- )
# Module-level logger. When run as a script, the __main__ block rebinds this
# name to the logger returned by setup_component_logger().
logger = logging.getLogger(name=__name__)
class DashboardAgent:
    """Per-node dashboard agent.

    The agent owns a ``GcsClient`` and loads every ``DashboardAgentModule``
    discovered by ``dashboard_utils.get_all_modules``. Outside minimal mode it
    also hosts a gRPC server and an ``HttpServerAgent`` that the modules run
    against. Public attributes set in ``__init__`` are read by agent modules.
    """

    def __init__(
        self,
        node_ip_address,
        grpc_port,
        gcs_address,
        cluster_id_hex,
        minimal,
        metrics_export_port=None,
        node_manager_port=None,
        events_export_addr=None,
        listen_port=ray_constants.DEFAULT_DASHBOARD_AGENT_LISTEN_PORT,
        disable_metrics_collection: bool = False,
        is_head: bool = False,
        *,  # the following are required kwargs
        object_store_name: str,
        raylet_name: str,
        log_dir: str,
        temp_dir: str,
        session_dir: str,
        session_name: str,
    ):
        """Initialize the DashboardAgent object.

        Args:
            node_ip_address: IP address of this node. Must not be None.
            grpc_port: Port for the agent gRPC server. 0 requests a dynamically
                assigned port; the bound port is stored back on
                ``self.grpc_port`` (non-minimal mode only).
            gcs_address: Address of the GCS. Must not be None.
            cluster_id_hex: Cluster ID in hex, forwarded to ``GcsClient``.
            minimal: When True, neither the gRPC server nor the HTTP server is
                created, and -1 is persisted for the metrics agent/export ports.
            metrics_export_port: Stored on the agent for use by modules.
            node_manager_port: Stored on the agent for use by modules.
            events_export_addr: Stored on the agent for use by modules.
            listen_port: Port for the HTTP server; 0 requests a dynamic port.
            disable_metrics_collection: Stored as
                ``metrics_collection_disabled`` for modules to consult.
            is_head: Whether this agent runs on the head node.
            object_store_name: Socket name of the plasma store.
            raylet_name: Socket path of the raylet process.
            log_dir: Log directory; also used by the raylet liveness check.
            temp_dir: Temporary directory used by Ray processes.
            session_dir: Session directory into which ports are persisted.
            session_name: Current Ray session name.

        Raises:
            AssertionError: If ``node_ip_address`` or ``gcs_address`` is None.
            KeyError: If the ``RAY_NODE_ID`` environment variable is unset.
        """
        # Public attributes are accessible for all agent modules.
        assert node_ip_address is not None
        self.ip = node_ip_address
        self.minimal = minimal
        assert gcs_address is not None
        self.gcs_address = gcs_address
        self.cluster_id_hex = cluster_id_hex
        self.temp_dir = temp_dir
        self.session_dir = session_dir
        self.log_dir = log_dir
        self.grpc_port = grpc_port
        self.metrics_export_port = metrics_export_port
        self.node_manager_port = node_manager_port
        self.events_export_addr = events_export_addr
        self.listen_port = listen_port
        self.object_store_name = object_store_name
        self.raylet_name = raylet_name
        # The node id comes from the environment, not from a constructor arg.
        self.node_id = os.environ["RAY_NODE_ID"]
        self.metrics_collection_disabled = disable_metrics_collection
        self.session_name = session_name

        # grpc server is None in minimal.
        self.server = None
        # http_server is None in minimal.
        self.http_server = None

        # Used by the agent and sub-modules.
        self.gcs_client = GcsClient(
            address=self.gcs_address,
            cluster_id=self.cluster_id_hex,
        )

        self.is_head = is_head

        if not self.minimal:
            self._init_non_minimal()
        else:
            # Write -1 to indicate the service is not in use.
            persist_port(
                self.session_dir,
                self.node_id,
                METRICS_AGENT_PORT_NAME,
                -1,
            )
            # This metric export port is run by reporter module
            # which is not included in minimal mode.
            persist_port(
                self.session_dir,
                self.node_id,
                METRICS_EXPORT_PORT_NAME,
                -1,
            )

    def _init_non_minimal(self):
        """Create the gRPC server and the HTTP server agent.

        Only called when the agent is not minimal. The imports are kept local
        so that minimal mode never loads them. The bound gRPC port is persisted
        under ``METRICS_AGENT_PORT_NAME``.
        """
        from grpc import aio as aiogrpc

        from ray._private.authentication.authentication_utils import (
            is_token_auth_enabled,
        )
        from ray._private.authentication.grpc_authentication_server_interceptor import (
            AsyncAuthenticationServerInterceptor,
        )
        from ray._private.tls_utils import add_port_to_grpc_server
        from ray.dashboard.http_server_agent import HttpServerAgent

        # We would want to suppress deprecating warnings from aiogrpc library
        # with the usage of asyncio.get_event_loop() in python version >=3.10
        # This could be removed once https://github.com/grpc/grpc/issues/32526
        # is released, and we used higher versions of grpcio that that.
        # Compare the version as a tuple so that e.g. 4.0 is also ">= 3.10".
        if sys.version_info >= (3, 10):
            import warnings

            with warnings.catch_warnings():
                warnings.simplefilter("ignore", category=DeprecationWarning)
                aiogrpc.init_grpc_aio()
        else:
            aiogrpc.init_grpc_aio()

        # Add authentication interceptor if token auth is enabled
        interceptors = []
        if is_token_auth_enabled():
            interceptors.append(AsyncAuthenticationServerInterceptor())

        self.server = aiogrpc.server(
            interceptors=interceptors,
            options=(
                ("grpc.so_reuseport", 0),
                (
                    "grpc.max_send_message_length",
                    AGENT_GRPC_MAX_MESSAGE_LENGTH,
                ),  # noqa
                (
                    "grpc.max_receive_message_length",
                    AGENT_GRPC_MAX_MESSAGE_LENGTH,
                ),
            ),  # noqa
        )
        # grpc_port can be 0 for dynamic port assignment. get the actual bound port.
        self.grpc_port = add_port_to_grpc_server(
            self.server, build_address(self.ip, self.grpc_port)
        )
        # Also listen on loopback when the node IP is not already localhost.
        if not is_localhost(self.ip):
            add_port_to_grpc_server(self.server, f"127.0.0.1:{self.grpc_port}")
        persist_port(
            self.session_dir,
            self.node_id,
            METRICS_AGENT_PORT_NAME,
            self.grpc_port,
        )
        logger.info(
            "Dashboard agent grpc address: %s",
            build_address(self.ip, self.grpc_port),
        )

        # If the agent is not minimal it should start the http server
        # to communicate with the dashboard in a head node.
        # Http server is not started in the minimal version because
        # it requires additional dependencies that are not
        # included in the minimal ray package.
        self.http_server = HttpServerAgent(self.ip, self.listen_port)

    def _load_modules(self):
        """Instantiate every discovered ``DashboardAgentModule`` subclass.

        Each module class is constructed with this agent as its only argument.

        Returns:
            The list of module instances, in discovery order.
        """
        modules = []
        agent_cls_list = dashboard_utils.get_all_modules(
            dashboard_utils.DashboardAgentModule
        )
        for cls in agent_cls_list:
            logger.info(
                "Loading %s: %s", dashboard_utils.DashboardAgentModule.__name__, cls
            )
            c = cls(self)
            modules.append(c)
        logger.info("Loaded %d modules.", len(modules))
        return modules

    @property
    def http_session(self):
        """HTTP session owned by the HTTP server agent.

        Raises:
            AssertionError: In minimal mode, where there is no HTTP server.
        """
        assert (
            self.http_server
        ), "Accessing unsupported API (HttpServerAgent) in a minimal ray."
        return self.http_server.http_session

    def get_node_id(self) -> str:
        """Return the node id read from ``RAY_NODE_ID`` at construction."""
        return self.node_id

    async def run(self):
        """Start the servers and modules, then run until they finish.

        The steps are:
        1. Start the gRPC server, if one exists.
        2. Load the modules.
        3. Start the HTTP server, if one exists.
        4. Persist the HTTP listen port, or -1 when the server is absent or
           failed to start.
        5. Unless the HTTP server failed to start, publish the agent address
           to the GCS internal KV.
        6. Await every module's ``run`` coroutine. Outside Windows/Cygwin this
           also includes a raylet liveness check, and in every case a task that
           keeps the agent alive. The HTTP server is cleaned up afterwards,
           even if a task raises or the run is cancelled.
        """
        # Start a grpc asyncio server.
        if self.server:
            await self.server.start()

        modules = self._load_modules()

        launch_http_server = True
        if self.http_server:
            try:
                await self.http_server.start(modules)
                # listen_port can be 0 for dynamic port assignment. get the actual bound port.
                self.listen_port = self.http_server.http_port
            except Exception as e:
                # TODO(kevin85421): We should fail the agent if the HTTP server
                # fails to start to avoid hiding the root cause. However,
                # agent processes are not cleaned up correctly after some tests
                # finish. If we fail the agent, the CI will always fail until
                # we fix the leak.
                logger.exception(
                    f"Failed to start HTTP server with exception: {e}. "
                    "The agent will stay alive but the HTTP service will be disabled.",
                )
                launch_http_server = False

        # If the HTTP server fails to start or is not launched, we should
        # persist -1 to indicate that the service is not available.
        persist_port(
            self.session_dir,
            self.node_id,
            DASHBOARD_AGENT_LISTEN_PORT_NAME,
            self.listen_port if self.http_server and launch_http_server else -1,
        )

        if launch_http_server:
            # Writes agent address to kv.
            # DASHBOARD_AGENT_ADDR_NODE_ID_PREFIX: <node_id> -> (ip, http_port, grpc_port)
            # DASHBOARD_AGENT_ADDR_IP_PREFIX: <ip> -> (node_id, http_port, grpc_port)
            # -1 should indicate that http server is not started.
            http_port = -1 if not self.http_server else self.http_server.http_port
            grpc_port = -1 if not self.server else self.grpc_port
            put_by_node_id = self.gcs_client.async_internal_kv_put(
                f"{dashboard_consts.DASHBOARD_AGENT_ADDR_NODE_ID_PREFIX}{self.node_id}".encode(),
                json.dumps([self.ip, http_port, grpc_port]).encode(),
                True,
                namespace=ray_constants.KV_NAMESPACE_DASHBOARD,
            )
            put_by_ip = self.gcs_client.async_internal_kv_put(
                f"{dashboard_consts.DASHBOARD_AGENT_ADDR_IP_PREFIX}{self.ip}".encode(),
                json.dumps([self.node_id, http_port, grpc_port]).encode(),
                True,
                namespace=ray_constants.KV_NAMESPACE_DASHBOARD,
            )

            await asyncio.gather(put_by_node_id, put_by_ip)

        tasks = [m.run(self.server) for m in modules]

        if sys.platform not in ["win32", "cygwin"]:

            def callback(msg):
                logger.info(
                    f"Terminated Raylet: ip={self.ip}, node_id={self.node_id}. {msg}"
                )

            # Pass the loop this coroutine is actually running on. Previously
            # this read a module-level `loop` global that only exists when the
            # file is executed as a script, which raised NameError otherwise.
            check_parent_task = create_check_raylet_task(
                self.log_dir, self.gcs_client, callback, asyncio.get_running_loop()
            )
            tasks.append(check_parent_task)

        # Keep the agent alive: wait on the gRPC server if present, else forever.
        if self.server:
            tasks.append(self.server.wait_for_termination())
        else:

            async def wait_forever():
                while True:
                    await asyncio.sleep(3600)

            tasks.append(wait_forever())

        try:
            await asyncio.gather(*tasks)
        finally:
            # Release HTTP server resources even when a task fails or the run
            # is cancelled, not only on normal completion.
            if self.http_server:
                await self.http_server.cleanup()
if __name__ == "__main__":
    # Script entry point: parse CLI flags, set up logging and stdout/stderr
    # redirection, build a DashboardAgent and run it on the event loop.
    parser = argparse.ArgumentParser(description="Dashboard agent.")
    # NOTE(review): --node-id is parsed but not forwarded to DashboardAgent,
    # which reads the node id from the RAY_NODE_ID environment variable.
    parser.add_argument(
        "--node-id",
        required=True,
        type=str,
        help="the unique ID of this node.",
    )
    parser.add_argument(
        "--node-ip-address",
        required=True,
        type=str,
        help="the IP address of this node.",
    )
    parser.add_argument(
        "--gcs-address", required=True, type=str, help="The address (ip:port) of GCS."
    )
    parser.add_argument(
        "--cluster-id-hex",
        required=True,
        type=str,
        help="The cluster id in hex.",
    )
    parser.add_argument(
        "--metrics-export-port",
        required=True,
        type=int,
        help="The port to expose metrics through Prometheus.",
    )
    parser.add_argument(
        "--grpc-port",
        required=True,
        type=int,
        help="The port on which the dashboard agent will receive GRPCs.",
    )
    parser.add_argument(
        "--node-manager-port",
        required=True,
        type=int,
        help="The port to use for starting the node manager",
    )
    parser.add_argument(
        "--object-store-name",
        required=True,
        type=str,
        default=None,
        help="The socket name of the plasma store",
    )
    parser.add_argument(
        "--listen-port",
        required=False,
        type=int,
        default=ray_constants.DEFAULT_DASHBOARD_AGENT_LISTEN_PORT,
        help="Port for HTTP server to listen on",
    )
    parser.add_argument(
        "--raylet-name",
        required=True,
        type=str,
        default=None,
        help="The socket path of the raylet process",
    )
    parser.add_argument(
        "--logging-level",
        required=False,
        type=lambda s: logging.getLevelName(s.upper()),
        default=ray_constants.LOGGER_LEVEL,
        choices=ray_constants.LOGGER_LEVEL_CHOICES,
        help=ray_constants.LOGGER_LEVEL_HELP,
    )
    parser.add_argument(
        "--logging-format",
        required=False,
        type=str,
        default=ray_constants.LOGGER_FORMAT,
        help=ray_constants.LOGGER_FORMAT_HELP,
    )
    parser.add_argument(
        "--logging-filename",
        required=False,
        type=str,
        default=dashboard_consts.DASHBOARD_AGENT_LOG_FILENAME,
        help="Specify the name of log file, "
        'log to stdout if set empty, default is "{}".'.format(
            dashboard_consts.DASHBOARD_AGENT_LOG_FILENAME
        ),
    )
    parser.add_argument(
        "--logging-rotate-bytes",
        required=True,
        type=int,
        help="Specify the max bytes for rotating log file.",
    )
    parser.add_argument(
        "--logging-rotate-backup-count",
        required=True,
        type=int,
        help="Specify the backup count of rotated log file.",
    )
    parser.add_argument(
        "--log-dir",
        required=True,
        type=str,
        default=None,
        help="Specify the path of log directory.",
    )
    parser.add_argument(
        "--temp-dir",
        required=True,
        type=str,
        default=None,
        help="Specify the path of the temporary directory use by Ray process.",
    )
    parser.add_argument(
        "--session-dir",
        required=True,
        type=str,
        default=None,
        help="Specify the path of this session.",
    )
    parser.add_argument(
        "--minimal",
        action="store_true",
        help=(
            "Minimal agent only contains a subset of features that don't "
            "require additional dependencies installed when ray is installed "
            "by `pip install 'ray[default]'`."
        ),
    )
    parser.add_argument(
        "--disable-metrics-collection",
        action="store_true",
        help=("If this arg is set, metrics report won't be enabled from the agent."),
    )
    parser.add_argument(
        "--head",
        action="store_true",
        help="Whether this node is the head node.",
    )
    parser.add_argument(
        "--session-name",
        required=False,
        type=str,
        default=None,
        help="The current Ray session name.",
    )
    parser.add_argument(
        "--stdout-filepath",
        required=False,
        type=str,
        default="",
        help="The filepath to dump dashboard agent stdout.",
    )
    parser.add_argument(
        "--stderr-filepath",
        required=False,
        type=str,
        default="",
        help="The filepath to dump dashboard agent stderr.",
    )

    args = parser.parse_args()

    try:
        # Disable log rotation for windows platform.
        logging_rotation_bytes = (
            args.logging_rotate_bytes if sys.platform != "win32" else 0
        )
        logging_rotation_backup_count = (
            args.logging_rotate_backup_count if sys.platform != "win32" else 1
        )
        logger = setup_component_logger(
            logging_level=args.logging_level,
            logging_format=args.logging_format,
            log_dir=args.log_dir,
            filename=args.logging_filename,
            max_bytes=logging_rotation_bytes,
            backup_count=logging_rotation_backup_count,
        )

        # Setup stdout/stderr redirect files if redirection enabled.
        logging_utils.redirect_stdout_stderr_if_needed(
            args.stdout_filepath,
            args.stderr_filepath,
            logging_rotation_bytes,
            logging_rotation_backup_count,
        )

        # Initialize event loop, see Dashboard init code for caveat
        # w.r.t grpc server init in the DashboardAgent initializer.
        # Kept as a module-level name: older DashboardAgent.run() code reads it.
        loop = get_or_create_event_loop()

        agent = DashboardAgent(
            args.node_ip_address,
            args.grpc_port,
            args.gcs_address,
            args.cluster_id_hex,
            args.minimal,
            temp_dir=args.temp_dir,
            session_dir=args.session_dir,
            log_dir=args.log_dir,
            metrics_export_port=args.metrics_export_port,
            node_manager_port=args.node_manager_port,
            listen_port=args.listen_port,
            object_store_name=args.object_store_name,
            raylet_name=args.raylet_name,
            disable_metrics_collection=args.disable_metrics_collection,
            is_head=args.head,
            session_name=args.session_name,
        )

        ray._raylet.setproctitle(ray_constants.AGENT_PROCESS_TYPE_DASHBOARD_AGENT)

        def sigterm_handler():
            logger.warning("Exiting with SIGTERM immediately...")
            # Exits with status int(signal.SIGTERM) (15 on Linux/macOS), not 0.
            # NOTE(review): confirm how the raylet classifies this exit status.
            os._exit(signal.SIGTERM)

        if sys.platform != "win32":
            # TODO(rickyyx): we currently do not have any logic for actual
            # graceful termination in the agent. Most of the underlying
            # async tasks run by the agent head doesn't handle CancelledError.
            # So a truly graceful shutdown is not trivial w/o much refactoring.
            # Re-open the issue: https://github.com/ray-project/ray/issues/25518
            # if a truly graceful shutdown is required.
            loop.add_signal_handler(signal.SIGTERM, sigterm_handler)

        loop.run_until_complete(agent.run())
    except Exception:
        logger.exception("Agent is working abnormally. It will exit immediately.")
        # sys.exit rather than the site-injected builtin `exit`, which is
        # absent when Python runs with -S.
        sys.exit(1)
|