http_server_agent.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. import asyncio
  2. import logging
  3. import random
  4. from typing import List, Optional
  5. from packaging.version import Version
  6. import ray.dashboard.optional_utils as dashboard_optional_utils
  7. from ray._common.network_utils import build_address, is_localhost
  8. from ray._common.utils import get_or_create_event_loop
  9. from ray._private.authentication.http_token_authentication import (
  10. get_token_auth_middleware,
  11. )
  12. from ray.dashboard.optional_deps import aiohttp, aiohttp_cors, hdrs
  13. logger = logging.getLogger(__name__)
  14. routes = dashboard_optional_utils.DashboardAgentRouteTable
  15. # Health check endpoints should remain public (no auth required)
  16. PUBLIC_EXACT_PATHS = [
  17. "/api/healthz",
  18. "/api/local_raylet_healthz",
  19. ]
  20. class HttpServerAgent:
  21. def __init__(self, ip: str, listen_port: int) -> None:
  22. self.ip = ip
  23. self.listen_port = listen_port
  24. self.http_host = None
  25. self.http_port = None
  26. self.http_session = None
  27. self.runner = None
  28. async def _start_site_with_retry(
  29. self, max_retries: int = 5, base_delay: float = 0.1
  30. ) -> aiohttp.web.TCPSite:
  31. """Start the TCP site with retry logic and exponential backoff.
  32. Args:
  33. max_retries: Maximum number of retry attempts
  34. base_delay: Base delay in seconds for exponential backoff
  35. Returns:
  36. The started site object
  37. Raises:
  38. OSError: If all retry attempts fail
  39. """
  40. last_exception: Optional[OSError] = None
  41. for attempt in range(max_retries + 1): # +1 for initial attempt
  42. try:
  43. site = aiohttp.web.TCPSite(
  44. self.runner,
  45. self.ip,
  46. self.listen_port,
  47. )
  48. await site.start()
  49. if not is_localhost(self.ip):
  50. local_site = aiohttp.web.TCPSite(
  51. self.runner,
  52. "127.0.0.1",
  53. self.listen_port,
  54. )
  55. await local_site.start()
  56. if attempt > 0:
  57. logger.info(
  58. f"Successfully started agent on port {self.listen_port} "
  59. f"after {attempt} retry attempts"
  60. )
  61. return site
  62. except OSError as e:
  63. last_exception = e
  64. if attempt < max_retries:
  65. # Calculate exponential backoff with jitter
  66. delay = base_delay * (2**attempt) + random.uniform(0, 0.1)
  67. logger.warning(
  68. f"Failed to bind to port {self.listen_port} (attempt {attempt + 1}/"
  69. f"{max_retries + 1}). Retrying in {delay:.2f}s. Error: {e}"
  70. )
  71. await asyncio.sleep(delay)
  72. else:
  73. logger.exception(
  74. f"Agent port #{self.listen_port} failed to bind after "
  75. f"{max_retries + 1} attempts."
  76. )
  77. break
  78. # If we get here, all retries failed
  79. raise last_exception
  80. async def start(self, modules: List) -> None:
  81. # Create a http session for all modules.
  82. # aiohttp<4.0.0 uses a 'loop' variable, aiohttp>=4.0.0 doesn't anymore
  83. if Version(aiohttp.__version__) < Version("4.0.0"):
  84. self.http_session = aiohttp.ClientSession(loop=get_or_create_event_loop())
  85. else:
  86. self.http_session = aiohttp.ClientSession()
  87. # Bind routes for every module so that each module
  88. # can use decorator-style routes.
  89. for c in modules:
  90. dashboard_optional_utils.DashboardAgentRouteTable.bind(c)
  91. app = aiohttp.web.Application(
  92. middlewares=[
  93. get_token_auth_middleware(aiohttp, PUBLIC_EXACT_PATHS),
  94. # Block all browser requests - agent is only accessed internally
  95. dashboard_optional_utils.get_browser_request_middleware(aiohttp),
  96. ]
  97. )
  98. app.add_routes(routes=routes.bound_routes())
  99. # Enable CORS on all routes.
  100. cors = aiohttp_cors.setup(
  101. app,
  102. defaults={
  103. "*": aiohttp_cors.ResourceOptions(
  104. allow_credentials=True,
  105. expose_headers="*",
  106. allow_methods="*",
  107. allow_headers=("Content-Type", "X-Header"),
  108. )
  109. },
  110. )
  111. for route in list(app.router.routes()):
  112. cors.add(route)
  113. self.runner = aiohttp.web.AppRunner(app)
  114. await self.runner.setup()
  115. # Start the site with retry logic
  116. site = await self._start_site_with_retry()
  117. self.http_host, self.http_port, *_ = site._server.sockets[0].getsockname()
  118. logger.info(
  119. "Dashboard agent http address: %s",
  120. build_address(self.http_host, self.http_port),
  121. )
  122. # Dump registered http routes.
  123. dump_routes = [r for r in app.router.routes() if r.method != hdrs.METH_HEAD]
  124. for r in dump_routes:
  125. logger.info(r)
  126. logger.info("Registered %s routes.", len(dump_routes))
  127. async def cleanup(self) -> None:
  128. # Wait for finish signal.
  129. await self.runner.cleanup()
  130. await self.http_session.close()