agent.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521
  1. import argparse
  2. import asyncio
  3. import json
  4. import logging
  5. import os
  6. import signal
  7. import sys
  8. import ray
  9. import ray._private.ray_constants as ray_constants
  10. import ray.dashboard.consts as dashboard_consts
  11. import ray.dashboard.utils as dashboard_utils
  12. from ray._common.network_utils import build_address, is_localhost
  13. from ray._common.utils import get_or_create_event_loop
  14. from ray._private import logging_utils
  15. from ray._private.process_watcher import create_check_raylet_task
  16. from ray._private.ray_constants import AGENT_GRPC_MAX_MESSAGE_LENGTH
  17. from ray._private.ray_logging import setup_component_logger
  18. from ray._raylet import (
  19. DASHBOARD_AGENT_LISTEN_PORT_NAME,
  20. METRICS_AGENT_PORT_NAME,
  21. METRICS_EXPORT_PORT_NAME,
  22. GcsClient,
  23. persist_port,
  24. )
  25. logger = logging.getLogger(__name__)
  26. class DashboardAgent:
  27. def __init__(
  28. self,
  29. node_ip_address,
  30. grpc_port,
  31. gcs_address,
  32. cluster_id_hex,
  33. minimal,
  34. metrics_export_port=None,
  35. node_manager_port=None,
  36. events_export_addr=None,
  37. listen_port=ray_constants.DEFAULT_DASHBOARD_AGENT_LISTEN_PORT,
  38. disable_metrics_collection: bool = False,
  39. is_head: bool = False,
  40. *, # the following are required kwargs
  41. object_store_name: str,
  42. raylet_name: str,
  43. log_dir: str,
  44. temp_dir: str,
  45. session_dir: str,
  46. session_name: str,
  47. ):
  48. """Initialize the DashboardAgent object."""
  49. # Public attributes are accessible for all agent modules.
  50. assert node_ip_address is not None
  51. self.ip = node_ip_address
  52. self.minimal = minimal
  53. assert gcs_address is not None
  54. self.gcs_address = gcs_address
  55. self.cluster_id_hex = cluster_id_hex
  56. self.temp_dir = temp_dir
  57. self.session_dir = session_dir
  58. self.log_dir = log_dir
  59. self.grpc_port = grpc_port
  60. self.metrics_export_port = metrics_export_port
  61. self.node_manager_port = node_manager_port
  62. self.events_export_addr = events_export_addr
  63. self.listen_port = listen_port
  64. self.object_store_name = object_store_name
  65. self.raylet_name = raylet_name
  66. self.node_id = os.environ["RAY_NODE_ID"]
  67. self.metrics_collection_disabled = disable_metrics_collection
  68. self.session_name = session_name
  69. # grpc server is None in mininal.
  70. self.server = None
  71. # http_server is None in minimal.
  72. self.http_server = None
  73. # Used by the agent and sub-modules.
  74. self.gcs_client = GcsClient(
  75. address=self.gcs_address,
  76. cluster_id=self.cluster_id_hex,
  77. )
  78. self.is_head = is_head
  79. if not self.minimal:
  80. self._init_non_minimal()
  81. else:
  82. # Write -1 to indicate the service is not in use.
  83. persist_port(
  84. self.session_dir,
  85. self.node_id,
  86. METRICS_AGENT_PORT_NAME,
  87. -1,
  88. )
  89. # This metric export port is run by reporter module
  90. # which is not included in minimal mode.
  91. persist_port(
  92. self.session_dir,
  93. self.node_id,
  94. METRICS_EXPORT_PORT_NAME,
  95. -1,
  96. )
  97. def _init_non_minimal(self):
  98. from grpc import aio as aiogrpc
  99. from ray._private.authentication.authentication_utils import (
  100. is_token_auth_enabled,
  101. )
  102. from ray._private.authentication.grpc_authentication_server_interceptor import (
  103. AsyncAuthenticationServerInterceptor,
  104. )
  105. from ray._private.tls_utils import add_port_to_grpc_server
  106. from ray.dashboard.http_server_agent import HttpServerAgent
  107. # We would want to suppress deprecating warnings from aiogrpc library
  108. # with the usage of asyncio.get_event_loop() in python version >=3.10
  109. # This could be removed once https://github.com/grpc/grpc/issues/32526
  110. # is released, and we used higher versions of grpcio that that.
  111. if sys.version_info.major >= 3 and sys.version_info.minor >= 10:
  112. import warnings
  113. with warnings.catch_warnings():
  114. warnings.simplefilter("ignore", category=DeprecationWarning)
  115. aiogrpc.init_grpc_aio()
  116. else:
  117. aiogrpc.init_grpc_aio()
  118. # Add authentication interceptor if token auth is enabled
  119. interceptors = []
  120. if is_token_auth_enabled():
  121. interceptors.append(AsyncAuthenticationServerInterceptor())
  122. self.server = aiogrpc.server(
  123. interceptors=interceptors,
  124. options=(
  125. ("grpc.so_reuseport", 0),
  126. (
  127. "grpc.max_send_message_length",
  128. AGENT_GRPC_MAX_MESSAGE_LENGTH,
  129. ), # noqa
  130. (
  131. "grpc.max_receive_message_length",
  132. AGENT_GRPC_MAX_MESSAGE_LENGTH,
  133. ),
  134. ), # noqa
  135. )
  136. # grpc_port can be 0 for dynamic port assignment. get the actual bound port.
  137. self.grpc_port = add_port_to_grpc_server(
  138. self.server, build_address(self.ip, self.grpc_port)
  139. )
  140. if not is_localhost(self.ip):
  141. add_port_to_grpc_server(self.server, f"127.0.0.1:{self.grpc_port}")
  142. persist_port(
  143. self.session_dir,
  144. self.node_id,
  145. METRICS_AGENT_PORT_NAME,
  146. self.grpc_port,
  147. )
  148. logger.info(
  149. "Dashboard agent grpc address: %s",
  150. build_address(self.ip, self.grpc_port),
  151. )
  152. # If the agent is not minimal it should start the http server
  153. # to communicate with the dashboard in a head node.
  154. # Http server is not started in the minimal version because
  155. # it requires additional dependencies that are not
  156. # included in the minimal ray package.
  157. self.http_server = HttpServerAgent(self.ip, self.listen_port)
  158. def _load_modules(self):
  159. """Load dashboard agent modules."""
  160. modules = []
  161. agent_cls_list = dashboard_utils.get_all_modules(
  162. dashboard_utils.DashboardAgentModule
  163. )
  164. for cls in agent_cls_list:
  165. logger.info(
  166. "Loading %s: %s", dashboard_utils.DashboardAgentModule.__name__, cls
  167. )
  168. c = cls(self)
  169. modules.append(c)
  170. logger.info("Loaded %d modules.", len(modules))
  171. return modules
  172. @property
  173. def http_session(self):
  174. assert (
  175. self.http_server
  176. ), "Accessing unsupported API (HttpServerAgent) in a minimal ray."
  177. return self.http_server.http_session
  178. def get_node_id(self) -> str:
  179. return self.node_id
  180. async def run(self):
  181. # Start a grpc asyncio server.
  182. if self.server:
  183. await self.server.start()
  184. modules = self._load_modules()
  185. launch_http_server = True
  186. if self.http_server:
  187. try:
  188. await self.http_server.start(modules)
  189. # listen_port can be 0 for dynamic port assignment. get the actual bound port.
  190. self.listen_port = self.http_server.http_port
  191. except Exception as e:
  192. # TODO(kevin85421): We should fail the agent if the HTTP server
  193. # fails to start to avoid hiding the root cause. However,
  194. # agent processes are not cleaned up correctly after some tests
  195. # finish. If we fail the agent, the CI will always fail until
  196. # we fix the leak.
  197. logger.exception(
  198. f"Failed to start HTTP server with exception: {e}. "
  199. "The agent will stay alive but the HTTP service will be disabled.",
  200. )
  201. launch_http_server = False
  202. # If the HTTP server fails to start or is not launched, we should
  203. # persist -1 to indicate that the service is not available.
  204. persist_port(
  205. self.session_dir,
  206. self.node_id,
  207. DASHBOARD_AGENT_LISTEN_PORT_NAME,
  208. self.listen_port if self.http_server and launch_http_server else -1,
  209. )
  210. if launch_http_server:
  211. # Writes agent address to kv.
  212. # DASHBOARD_AGENT_ADDR_NODE_ID_PREFIX: <node_id> -> (ip, http_port, grpc_port)
  213. # DASHBOARD_AGENT_ADDR_IP_PREFIX: <ip> -> (node_id, http_port, grpc_port)
  214. # -1 should indicate that http server is not started.
  215. http_port = -1 if not self.http_server else self.http_server.http_port
  216. grpc_port = -1 if not self.server else self.grpc_port
  217. put_by_node_id = self.gcs_client.async_internal_kv_put(
  218. f"{dashboard_consts.DASHBOARD_AGENT_ADDR_NODE_ID_PREFIX}{self.node_id}".encode(),
  219. json.dumps([self.ip, http_port, grpc_port]).encode(),
  220. True,
  221. namespace=ray_constants.KV_NAMESPACE_DASHBOARD,
  222. )
  223. put_by_ip = self.gcs_client.async_internal_kv_put(
  224. f"{dashboard_consts.DASHBOARD_AGENT_ADDR_IP_PREFIX}{self.ip}".encode(),
  225. json.dumps([self.node_id, http_port, grpc_port]).encode(),
  226. True,
  227. namespace=ray_constants.KV_NAMESPACE_DASHBOARD,
  228. )
  229. await asyncio.gather(put_by_node_id, put_by_ip)
  230. tasks = [m.run(self.server) for m in modules]
  231. if sys.platform not in ["win32", "cygwin"]:
  232. def callback(msg):
  233. logger.info(
  234. f"Terminated Raylet: ip={self.ip}, node_id={self.node_id}. {msg}"
  235. )
  236. check_parent_task = create_check_raylet_task(
  237. self.log_dir, self.gcs_client, callback, loop
  238. )
  239. tasks.append(check_parent_task)
  240. if self.server:
  241. tasks.append(self.server.wait_for_termination())
  242. else:
  243. async def wait_forever():
  244. while True:
  245. await asyncio.sleep(3600)
  246. tasks.append(wait_forever())
  247. await asyncio.gather(*tasks)
  248. if self.http_server:
  249. await self.http_server.cleanup()
  250. if __name__ == "__main__":
  251. parser = argparse.ArgumentParser(description="Dashboard agent.")
  252. parser.add_argument(
  253. "--node-id",
  254. required=True,
  255. type=str,
  256. help="the unique ID of this node.",
  257. )
  258. parser.add_argument(
  259. "--node-ip-address",
  260. required=True,
  261. type=str,
  262. help="the IP address of this node.",
  263. )
  264. parser.add_argument(
  265. "--gcs-address", required=True, type=str, help="The address (ip:port) of GCS."
  266. )
  267. parser.add_argument(
  268. "--cluster-id-hex",
  269. required=True,
  270. type=str,
  271. help="The cluster id in hex.",
  272. )
  273. parser.add_argument(
  274. "--metrics-export-port",
  275. required=True,
  276. type=int,
  277. help="The port to expose metrics through Prometheus.",
  278. )
  279. parser.add_argument(
  280. "--grpc-port",
  281. required=True,
  282. type=int,
  283. help="The port on which the dashboard agent will receive GRPCs.",
  284. )
  285. parser.add_argument(
  286. "--node-manager-port",
  287. required=True,
  288. type=int,
  289. help="The port to use for starting the node manager",
  290. )
  291. parser.add_argument(
  292. "--object-store-name",
  293. required=True,
  294. type=str,
  295. default=None,
  296. help="The socket name of the plasma store",
  297. )
  298. parser.add_argument(
  299. "--listen-port",
  300. required=False,
  301. type=int,
  302. default=ray_constants.DEFAULT_DASHBOARD_AGENT_LISTEN_PORT,
  303. help="Port for HTTP server to listen on",
  304. )
  305. parser.add_argument(
  306. "--raylet-name",
  307. required=True,
  308. type=str,
  309. default=None,
  310. help="The socket path of the raylet process",
  311. )
  312. parser.add_argument(
  313. "--logging-level",
  314. required=False,
  315. type=lambda s: logging.getLevelName(s.upper()),
  316. default=ray_constants.LOGGER_LEVEL,
  317. choices=ray_constants.LOGGER_LEVEL_CHOICES,
  318. help=ray_constants.LOGGER_LEVEL_HELP,
  319. )
  320. parser.add_argument(
  321. "--logging-format",
  322. required=False,
  323. type=str,
  324. default=ray_constants.LOGGER_FORMAT,
  325. help=ray_constants.LOGGER_FORMAT_HELP,
  326. )
  327. parser.add_argument(
  328. "--logging-filename",
  329. required=False,
  330. type=str,
  331. default=dashboard_consts.DASHBOARD_AGENT_LOG_FILENAME,
  332. help="Specify the name of log file, "
  333. 'log to stdout if set empty, default is "{}".'.format(
  334. dashboard_consts.DASHBOARD_AGENT_LOG_FILENAME
  335. ),
  336. )
  337. parser.add_argument(
  338. "--logging-rotate-bytes",
  339. required=True,
  340. type=int,
  341. help="Specify the max bytes for rotating log file.",
  342. )
  343. parser.add_argument(
  344. "--logging-rotate-backup-count",
  345. required=True,
  346. type=int,
  347. help="Specify the backup count of rotated log file.",
  348. )
  349. parser.add_argument(
  350. "--log-dir",
  351. required=True,
  352. type=str,
  353. default=None,
  354. help="Specify the path of log directory.",
  355. )
  356. parser.add_argument(
  357. "--temp-dir",
  358. required=True,
  359. type=str,
  360. default=None,
  361. help="Specify the path of the temporary directory use by Ray process.",
  362. )
  363. parser.add_argument(
  364. "--session-dir",
  365. required=True,
  366. type=str,
  367. default=None,
  368. help="Specify the path of this session.",
  369. )
  370. parser.add_argument(
  371. "--minimal",
  372. action="store_true",
  373. help=(
  374. "Minimal agent only contains a subset of features that don't "
  375. "require additional dependencies installed when ray is installed "
  376. "by `pip install 'ray[default]'`."
  377. ),
  378. )
  379. parser.add_argument(
  380. "--disable-metrics-collection",
  381. action="store_true",
  382. help=("If this arg is set, metrics report won't be enabled from the agent."),
  383. )
  384. parser.add_argument(
  385. "--head",
  386. action="store_true",
  387. help="Whether this node is the head node.",
  388. )
  389. parser.add_argument(
  390. "--session-name",
  391. required=False,
  392. type=str,
  393. default=None,
  394. help="The current Ray session name.",
  395. )
  396. parser.add_argument(
  397. "--stdout-filepath",
  398. required=False,
  399. type=str,
  400. default="",
  401. help="The filepath to dump dashboard agent stdout.",
  402. )
  403. parser.add_argument(
  404. "--stderr-filepath",
  405. required=False,
  406. type=str,
  407. default="",
  408. help="The filepath to dump dashboard agent stderr.",
  409. )
  410. args = parser.parse_args()
  411. try:
  412. # Disable log rotation for windows platform.
  413. logging_rotation_bytes = (
  414. args.logging_rotate_bytes if sys.platform != "win32" else 0
  415. )
  416. logging_rotation_backup_count = (
  417. args.logging_rotate_backup_count if sys.platform != "win32" else 1
  418. )
  419. logger = setup_component_logger(
  420. logging_level=args.logging_level,
  421. logging_format=args.logging_format,
  422. log_dir=args.log_dir,
  423. filename=args.logging_filename,
  424. max_bytes=logging_rotation_bytes,
  425. backup_count=logging_rotation_backup_count,
  426. )
  427. # Setup stdout/stderr redirect files if redirection enabled.
  428. logging_utils.redirect_stdout_stderr_if_needed(
  429. args.stdout_filepath,
  430. args.stderr_filepath,
  431. logging_rotation_bytes,
  432. logging_rotation_backup_count,
  433. )
  434. # Initialize event loop, see Dashboard init code for caveat
  435. # w.r.t grpc server init in the DashboardAgent initializer.
  436. loop = get_or_create_event_loop()
  437. agent = DashboardAgent(
  438. args.node_ip_address,
  439. args.grpc_port,
  440. args.gcs_address,
  441. args.cluster_id_hex,
  442. args.minimal,
  443. temp_dir=args.temp_dir,
  444. session_dir=args.session_dir,
  445. log_dir=args.log_dir,
  446. metrics_export_port=args.metrics_export_port,
  447. node_manager_port=args.node_manager_port,
  448. listen_port=args.listen_port,
  449. object_store_name=args.object_store_name,
  450. raylet_name=args.raylet_name,
  451. disable_metrics_collection=args.disable_metrics_collection,
  452. is_head=args.head,
  453. session_name=args.session_name,
  454. )
  455. ray._raylet.setproctitle(ray_constants.AGENT_PROCESS_TYPE_DASHBOARD_AGENT)
  456. def sigterm_handler():
  457. logger.warning("Exiting with SIGTERM immediately...")
  458. # Exit code 0 will be considered as an expected shutdown
  459. os._exit(signal.SIGTERM)
  460. if sys.platform != "win32":
  461. # TODO(rickyyx): we currently do not have any logic for actual
  462. # graceful termination in the agent. Most of the underlying
  463. # async tasks run by the agent head doesn't handle CancelledError.
  464. # So a truly graceful shutdown is not trivial w/o much refactoring.
  465. # Re-open the issue: https://github.com/ray-project/ray/issues/25518
  466. # if a truly graceful shutdown is required.
  467. loop.add_signal_handler(signal.SIGTERM, sigterm_handler)
  468. loop.run_until_complete(agent.run())
  469. except Exception:
  470. logger.exception("Agent is working abnormally. It will exit immediately.")
  471. exit(1)