http_server_head.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413
  1. import asyncio
  2. import errno
  3. import ipaddress
  4. import logging
  5. import os
  6. import pathlib
  7. import posixpath
  8. import sys
  9. import time
  10. from math import floor
  11. from typing import List
  12. from packaging.version import Version
  13. import ray
  14. import ray.dashboard.optional_utils as dashboard_optional_utils
  15. import ray.dashboard.timezone_utils as timezone_utils
  16. import ray.dashboard.utils as dashboard_utils
  17. from ray import ray_constants
  18. from ray._common.network_utils import build_address, parse_address
  19. from ray._common.usage.usage_lib import TagKey, record_extra_usage_tag
  20. from ray._common.utils import get_or_create_event_loop
  21. from ray._private.authentication import (
  22. authentication_constants,
  23. authentication_utils as auth_utils,
  24. )
  25. from ray._private.authentication.http_token_authentication import (
  26. get_token_auth_middleware,
  27. )
  28. from ray._raylet import get_authentication_mode
  29. from ray.dashboard.dashboard_metrics import DashboardPrometheusMetrics
  30. from ray.dashboard.head import DashboardHeadModule
  31. # All third-party dependencies that are not included in the minimal Ray
  32. # installation must be included in this file. This allows us to determine if
  33. # the agent has the necessary dependencies to be started.
  34. from ray.dashboard.optional_deps import aiohttp, hdrs
  35. from ray.dashboard.subprocesses.handle import SubprocessModuleHandle
  36. from ray.dashboard.subprocesses.routes import SubprocessRouteTable
  # Module-level logger. Configure it at the program's entry point; Ray
  # installs a default configuration at its entry/init points.
  logger = logging.getLogger(__name__)
  routes = dashboard_optional_utils.DashboardHeadRouteTable

  # Opt-in switch that lets the "/static" route follow symlinks.
  # Only meant for special Ray installations whose dashboard build files are
  # symlinked from a different directory. For most users it is discouraged
  # because it can expose files that live outside the build directory. See:
  # https://docs.aiohttp.org/en/stable/web_reference.html#aiohttp.web.UrlDispatcher.add_static
  ENV_VAR_FOLLOW_SYMLINKS = "RAY_DASHBOARD_BUILD_FOLLOW_SYMLINKS"
  FOLLOW_SYMLINKS_ENABLED = os.getenv(ENV_VAR_FOLLOW_SYMLINKS) == "1"

  if FOLLOW_SYMLINKS_ENABLED:
      # Emitted once, at import time, so operators notice the risky setting.
      logger.warning(
          "Enabling RAY_DASHBOARD_BUILD_FOLLOW_SYMLINKS is not recommended as it "
          "allows symlinks to directories outside the dashboard build folder. "
          "You may accidentally expose files on your system outside of the "
          "build directory."
      )
  57. def setup_static_dir():
  58. build_dir = os.path.join(
  59. os.path.dirname(os.path.abspath(__file__)), "client", "build"
  60. )
  61. module_name = os.path.basename(os.path.dirname(__file__))
  62. if not os.path.isdir(build_dir):
  63. raise dashboard_utils.FrontendNotFoundError(
  64. errno.ENOENT,
  65. "Dashboard build directory not found. If installing "
  66. "from source, please follow the additional steps "
  67. "required to build the dashboard"
  68. f"(cd python/ray/{module_name}/client "
  69. "&& npm ci "
  70. "&& npm run build)",
  71. build_dir,
  72. )
  73. static_dir = os.path.join(build_dir, "static")
  74. routes.static("/static", static_dir, follow_symlinks=FOLLOW_SYMLINKS_ENABLED)
  75. return build_dir
  class HttpServerDashboardHead:
      """HTTP front end of the Ray dashboard head process.

      Serves the dashboard UI (index.html, favicon, static assets), a few
      built-in endpoints (timezone, authentication mode, token authentication),
      and the routes contributed by dashboard head modules and subprocess
      modules. Requests pass through a middleware chain for metrics, token
      auth, path-traversal protection, browser-request filtering and
      static-asset caching.
      """

      def __init__(
          self,
          ip: str,
          http_host: str,
          http_port: int,
          http_port_retries: int,
          gcs_address: str,
          session_name: str,
          metrics: DashboardPrometheusMetrics,
      ):
          """Register static/UI routes and create the shared HTTP client session.

          Args:
              ip: Address reported instead of an unspecified bind host once the
                  server is running (see ``run``).
              http_host: Host to bind the HTTP server to.
              http_port: First port to try binding.
              http_port_retries: Number of additional consecutive ports to try
                  if binding fails.
              gcs_address: GCS address; its host part is kept as
                  ``head_node_ip``.
              session_name: Ray session name, used as a metrics label.
              metrics: Prometheus metrics updated by ``metrics_middleware``.

          Raises:
              dashboard_utils.FrontendNotFoundError: If the dashboard build
                  directory is missing and the platform is not Windows/Cygwin.
          """
          self.ip = ip
          self.http_host = http_host
          self.http_port = http_port
          self.http_port_retries = http_port_retries
          self.head_node_ip = parse_address(gcs_address)[0]
          self.metrics = metrics
          self._session_name = session_name
          # The attributes below are filled in after `run` is invoked.
          self.runner = None
          # Set up dashboard routes.
          try:
              build_dir = setup_static_dir()
              logger.info("Setup static dir for dashboard: %s", build_dir)
          except dashboard_utils.FrontendNotFoundError as ex:
              # Don't raise FrontendNotFoundError on Windows because of NPM
              # incompatibilities there; see ci.sh::build_dashboard_front_end().
              if sys.platform in ["win32", "cygwin"]:
                  logger.warning(ex)
              else:
                  raise
          dashboard_optional_utils.DashboardHeadRouteTable.bind(self)
          # Create one HTTP client session shared by all modules (closed in
          # `cleanup`). aiohttp<4.0.0 takes a `loop` argument; >=4.0.0 does not.
          if Version(aiohttp.__version__) < Version("4.0.0"):
              self.http_session = aiohttp.ClientSession(loop=get_or_create_event_loop())
          else:
              self.http_session = aiohttp.ClientSession()

      @routes.get("/")
      async def get_index(self, req) -> aiohttp.web.FileResponse:
          """Serve the dashboard's index.html with caching disabled."""
          try:
              # This API is a no-op after the first report. The usage is always
              # recorded, but it is not reported if usage stats are disabled.
              record_extra_usage_tag(TagKey.DASHBOARD_USED, "True")
          except Exception as e:
              logger.warning(
                  "Failed to record the dashboard usage. "
                  "This error message is harmless and can be ignored. "
                  f"Error: {e}"
              )
          resp = aiohttp.web.FileResponse(
              os.path.join(
                  os.path.dirname(os.path.abspath(__file__)), "client/build/index.html"
              )
          )
          # index.html references the current build's assets; never cache it.
          resp.headers["Cache-Control"] = "no-store"
          return resp

      @routes.get("/favicon.ico")
      async def get_favicon(self, req) -> aiohttp.web.FileResponse:
          """Serve the dashboard favicon from the client build directory."""
          return aiohttp.web.FileResponse(
              os.path.join(
                  os.path.dirname(os.path.abspath(__file__)), "client/build/favicon.ico"
              )
          )

      @routes.get("/timezone")
      async def get_timezone(self, req) -> aiohttp.web.Response:
          """Return the server's current timezone info as JSON (500 on failure)."""
          try:
              current_timezone = timezone_utils.get_current_timezone_info()
              return aiohttp.web.json_response(current_timezone)
          except Exception as e:
              logger.error(f"Error getting timezone: {e}")
              return aiohttp.web.Response(
                  status=500, text="Internal Server Error: " + str(e)
              )

      @routes.get("/api/authentication_mode")
      async def get_authentication_mode(self, req) -> aiohttp.web.Response:
          """Return ``{"authentication_mode": <name>}`` for the current auth mode.

          When the mode is ``"disabled"``, the response also expires any
          previously issued authentication cookie.
          """
          try:
              mode = get_authentication_mode()
              mode_str = auth_utils.get_authentication_mode_name(mode)
              response = aiohttp.web.json_response({"authentication_mode": mode_str})
              # If auth is disabled, clear any existing authentication cookie.
              if mode_str == "disabled":
                  response.set_cookie(
                      authentication_constants.AUTHENTICATION_TOKEN_COOKIE_NAME,
                      "",
                      max_age=0,
                      path="/",
                  )
              return response
          except Exception as e:
              logger.error(f"Error getting authentication mode: {e}")
              return aiohttp.web.Response(
                  status=500, text="Internal Server Error: " + str(e)
              )

      @routes.post("/api/authenticate")
      async def authenticate(self, req) -> aiohttp.web.Response:
          """Validate a token and store it in an HttpOnly cookie.

          The token is read from the Authorization header. If it is valid, the
          response sets an HttpOnly cookie used for subsequent web-dashboard
          requests.

          Returns:
              401 if token auth is disabled or the header is missing, 403 if the
              token is invalid, 500 on unexpected errors, otherwise a JSON
              success response carrying the cookie.
          """
          try:
              if not auth_utils.is_token_auth_enabled():
                  return aiohttp.web.Response(
                      status=401,
                      text="Unauthorized: Token authentication is not enabled",
                  )
              auth_header = req.headers.get(
                  authentication_constants.AUTHORIZATION_HEADER_NAME, ""
              )
              if not auth_header:
                  return aiohttp.web.Response(
                      status=401,
                      text="Unauthorized: Missing authentication token",
                  )
              if not auth_utils.validate_request_token(auth_header):
                  return aiohttp.web.Response(
                      status=403,
                      text="Forbidden: Invalid authentication token",
                  )
              # Token is valid: strip the bearer prefix (matched case-insensitively)
              # so that only the raw token value is stored in the cookie.
              token = auth_header
              bearer_prefix = authentication_constants.AUTHORIZATION_BEARER_PREFIX
              if auth_header.lower().startswith(bearer_prefix.lower()):
                  token = auth_header[len(bearer_prefix) :]
              response = aiohttp.web.json_response(
                  {"status": "authenticated", "message": "Token is valid"}
              )
              # Only mark the cookie Secure when this request arrived over HTTPS.
              is_secure = req.scheme == "https"
              response.set_cookie(
                  authentication_constants.AUTHENTICATION_TOKEN_COOKIE_NAME,
                  token,
                  max_age=authentication_constants.AUTHENTICATION_TOKEN_COOKIE_MAX_AGE,
                  path="/",
                  httponly=True,  # Not readable from JavaScript (limits XSS theft).
                  samesite="Strict",  # Not sent on cross-site requests (CSRF).
                  secure=is_secure,
              )
              return response
          except Exception as e:
              logger.error(f"Error during authentication: {e}")
              return aiohttp.web.Response(
                  status=500, text="Internal Server Error: " + str(e)
              )

      def get_address(self):
          """Return ``(http_host, http_port)``; both must already be set."""
          assert self.http_host and self.http_port
          return self.http_host, self.http_port

      @aiohttp.web.middleware
      async def path_clean_middleware(self, request, handler):
          """Reject ``/static`` and ``/logs`` requests that escape their prefix.

          The URL path is normalized lexically with ``posixpath.normpath``, which
          collapses ``.``/``..`` segments and duplicate slashes. Unlike
          ``realpath``, ``normpath`` never consults the server's filesystem, so a
          symlink such as a host-level ``/logs`` cannot make legitimate requests
          look like traversal.

          Raises:
              aiohttp.web.HTTPForbidden: If the normalized path is neither the
                  expected prefix nor located under it.
          """
          path = request.path
          if path.startswith(("/static", "/logs")):
              parent = pathlib.PurePosixPath(
                  "/logs" if path.startswith("/logs") else "/static"
              )
              # If the destination is not relative to the expected directory,
              # the user is attempting path traversal, so deny the request.
              request_path = pathlib.PurePosixPath(posixpath.normpath(path))
              if request_path != parent and parent not in request_path.parents:
                  logger.info(
                      f"Rejecting {request_path=} because it is not relative to {parent=}"
                  )
                  raise aiohttp.web.HTTPForbidden()
          return await handler(request)

      @aiohttp.web.middleware
      async def metrics_middleware(self, request, handler):
          """Record request latency and count per endpoint and status class.

          The status class is ``"<first digit>xx"``. Handlers and inner
          middlewares may answer by *raising* an ``aiohttp.web.HTTPException``
          (e.g. ``HTTPForbidden`` from ``path_clean_middleware``). Those are
          labelled with their real status rather than as server errors. Any
          other exception, including cancellation, is labelled ``"5xx"``.
          Failures while emitting metrics are logged and never affect the
          response.
          """
          start_time = time.monotonic()
          # Fallback label if the handler exits via an exception not caught below.
          status_tag = "5xx"
          try:
              response = await handler(request)
              status_tag = f"{response.status // 100}xx"
              return response
          except aiohttp.web.HTTPException as e:
              status_tag = f"{e.status // 100}xx"
              raise
          except (Exception, asyncio.CancelledError):
              status_tag = "5xx"
              raise
          finally:
              resp_time = time.monotonic() - start_time
              try:
                  self.metrics.metrics_request_duration.labels(
                      endpoint=handler.__name__,
                      http_status=status_tag,
                      Version=ray.__version__,
                      SessionName=self._session_name,
                      Component="dashboard",
                  ).observe(resp_time)
                  self.metrics.metrics_request_count.labels(
                      method=request.method,
                      endpoint=handler.__name__,
                      http_status=status_tag,
                      Version=ray.__version__,
                      SessionName=self._session_name,
                      Component="dashboard",
                  ).inc()
              except Exception as e:
                  logger.exception(f"Error emitting api metrics: {e}")

      @aiohttp.web.middleware
      async def cache_control_static_middleware(self, request, handler):
          """Let browsers cache successful ``/static`` responses for one year."""
          response = await handler(request)
          if request.path.startswith("/static"):
              response.headers["Cache-Control"] = "max-age=31536000"
          return response

      async def run(
          self,
          dashboard_head_modules: List[DashboardHeadModule],
          subprocess_module_handles: List[SubprocessModuleHandle],
      ):
          """Bind all module routes, start the HTTP server and record its address.

          Tries ``http_port`` first and then up to ``http_port_retries``
          following ports. Afterwards ``self.http_host``/``self.http_port`` hold
          the bound address, with an unspecified bind host (e.g. ``0.0.0.0``)
          replaced by ``self.ip``.

          Raises:
              Exception: If no port in the tried range could be bound.
          """
          # Bind the http routes of each module.
          for m in dashboard_head_modules:
              dashboard_optional_utils.DashboardHeadRouteTable.bind(m)
          for h in subprocess_module_handles:
              SubprocessRouteTable.bind(h)

          # Public endpoints that don't require authentication; they are needed
          # for the dashboard to load and to request an auth token.
          public_exact_paths = {
              "/",  # Root index.html
              "/favicon.ico",
              "/api/authentication_mode",
              "/api/authenticate",  # Token authentication endpoint
              "/api/healthz",  # General healthcheck
              "/api/gcs_healthz",  # GCS health check
              "/api/local_raylet_healthz",  # Raylet health check
              "/-/healthz",  # Serve health check
          }
          public_path_prefixes = ("/static/",)  # Static assets (JS, CSS, images)

          # The http server must be created after all modules are loaded.
          # working_dir uploads for job submission can be up to 100MiB.
          app = aiohttp.web.Application(
              client_max_size=ray_constants.DASHBOARD_CLIENT_MAX_SIZE,
              middlewares=[
                  self.metrics_middleware,
                  get_token_auth_middleware(
                      aiohttp, public_exact_paths, public_path_prefixes
                  ),
                  self.path_clean_middleware,
                  dashboard_optional_utils.get_browser_request_middleware(
                      aiohttp,
                      allowed_methods={"GET", "HEAD", "OPTIONS"},
                      allowed_paths=["/api/authenticate"],
                  ),
                  self.cache_control_static_middleware,
              ],
          )
          app.add_routes(routes=routes.bound_routes())
          app.add_routes(routes=SubprocessRouteTable.bound_routes())

          self.runner = aiohttp.web.AppRunner(
              app,
              access_log_format=(
                  "%a %t '%r' %s %b bytes %D us '%{Referer}i' '%{User-Agent}i'"
              ),
          )
          await self.runner.setup()

          last_ex = None
          for _ in range(1 + self.http_port_retries):
              try:
                  site = aiohttp.web.TCPSite(self.runner, self.http_host, self.http_port)
                  await site.start()
                  break
              except OSError as e:
                  last_ex = e
                  self.http_port += 1
                  logger.warning("Try to use port %s: %s", self.http_port, e)
          else:
              raise Exception(
                  f"Failed to find a valid port for dashboard after "
                  f"{self.http_port_retries} retries: {last_ex}"
              ) from last_ex

          # Read back the actual bound address (resolves port 0, for example).
          self.http_host, self.http_port, *_ = site._server.sockets[0].getsockname()
          # A wildcard bind address is not reachable; advertise the node ip.
          self.http_host = (
              self.ip
              if ipaddress.ip_address(self.http_host).is_unspecified
              else self.http_host
          )
          logger.info(
              "Dashboard head http address: %s",
              build_address(self.http_host, self.http_port),
          )

          # Dump registered http routes (HEAD duplicates of GETs are skipped).
          dump_routes = [r for r in app.router.routes() if r.method != hdrs.METH_HEAD]
          for r in dump_routes:
              logger.info(r)
          logger.info("Registered %s routes.", len(dump_routes))

      async def cleanup(self):
          """Stop the HTTP server and close the shared HTTP client session.

          Safe to call when ``run`` was never invoked (``self.runner`` is still
          ``None``). The client session is closed even if stopping the server
          fails; ``ClientSession.close()`` is a no-op on an already-closed
          session.
          """
          try:
              if self.runner is not None:
                  await self.runner.cleanup()
          finally:
              await self.http_session.close()