train_head.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474
  1. import logging
  2. from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
  3. from aiohttp.web import Request, Response
  4. import ray
  5. import ray.dashboard.optional_utils as dashboard_optional_utils
  6. from ray.core.generated import gcs_service_pb2_grpc
  7. from ray.dashboard.modules.job.common import JobInfoStorageClient
  8. from ray.dashboard.modules.job.utils import find_jobs_by_job_ids
  9. from ray.dashboard.subprocesses.module import SubprocessModule
  10. from ray.dashboard.subprocesses.routes import SubprocessRouteTable as routes
  11. from ray.dashboard.subprocesses.utils import get_http_session_to_module
  12. from ray.util.annotations import DeveloperAPI
  13. if TYPE_CHECKING:
  14. from ray.dashboard.modules.job.pydantic_models import JobDetails
  15. from ray.train.v2._internal.state.schema import (
  16. DecoratedTrainRun,
  17. DecoratedTrainRunAttempt,
  18. DecoratedTrainWorker,
  19. RunStatus,
  20. TrainRun,
  21. TrainRunAttempt,
  22. TrainWorker,
  23. )
  24. logger = logging.getLogger(__name__)
  25. logger.setLevel(logging.INFO)
  26. class TrainHead(SubprocessModule):
  27. def __init__(self, *args, **kwargs):
  28. super().__init__(*args, **kwargs)
  29. self._train_stats_actor = None # Train V1
  30. self._train_v2_state_actor = None # Train V2
  31. self._job_info_client = None
  32. self._gcs_actor_info_stub = None
  33. # Lazy initialized HTTP session to NodeHead
  34. self._node_head_http_session = None
  35. # TODO: The next iteration of this should be "/api/train/v2/runs/v2".
  36. # This follows the naming convention of "/api/train/{train_version}/runs/{api_version}".
  37. # This API corresponds to the Train V2 API.
  38. @routes.get("/api/train/v2/runs/v1")
  39. @dashboard_optional_utils.init_ray_and_catch_exceptions()
  40. @DeveloperAPI
  41. async def get_train_v2_runs(self, req: Request) -> Response:
  42. """Get all TrainRuns for Ray Train V2."""
  43. try:
  44. from ray.train.v2._internal.state.schema import TrainRunsResponse
  45. except ImportError:
  46. logger.exception(
  47. "Train is not installed. Please run `pip install ray[train]` "
  48. "when setting up Ray on your cluster."
  49. )
  50. return Response(
  51. status=500,
  52. text="Train is not installed. Please run `pip install ray[train]` "
  53. "when setting up Ray on your cluster.",
  54. )
  55. state_actor = await self.get_train_v2_state_actor()
  56. if state_actor is None:
  57. return Response(
  58. status=500,
  59. text=(
  60. "Train state data is not available. Please make sure Ray Train "
  61. "is running and that the Train state actor is enabled by setting "
  62. 'the RAY_TRAIN_ENABLE_STATE_TRACKING environment variable to "1".'
  63. ),
  64. )
  65. else:
  66. try:
  67. train_runs = await state_actor.get_train_runs.remote()
  68. decorated_train_runs = await self._decorate_train_runs(
  69. train_runs.values()
  70. )
  71. details = TrainRunsResponse(train_runs=decorated_train_runs)
  72. except ray.exceptions.RayTaskError as e:
  73. # Task failure sometimes are due to GCS
  74. # failure. When GCS failed, we expect a longer time
  75. # to recover.
  76. return Response(
  77. status=503,
  78. text=(
  79. "Failed to get a response from the train stats actor. "
  80. f"The GCS may be down, please retry later: {e}"
  81. ),
  82. )
  83. return Response(
  84. text=details.json(),
  85. content_type="application/json",
  86. )
  87. async def _decorate_train_runs(
  88. self, train_runs: List["TrainRun"]
  89. ) -> List["DecoratedTrainRun"]:
  90. """Decorate the train runs with run attempts, job details, status, and status details.
  91. Returns:
  92. List[DecoratedTrainRun]: The decorated train runs in reverse chronological order.
  93. """
  94. from ray.train.v2._internal.state.schema import DecoratedTrainRun
  95. decorated_train_runs: List[DecoratedTrainRun] = []
  96. state_actor = await self.get_train_v2_state_actor()
  97. all_train_run_attempts = await state_actor.get_train_run_attempts.remote()
  98. jobs = await self._get_jobs([train_run.job_id for train_run in train_runs])
  99. for train_run in train_runs:
  100. # TODO: Batch these together across TrainRuns if needed.
  101. train_run_attempts = all_train_run_attempts[train_run.id].values()
  102. decorated_train_run_attempts: List[
  103. DecoratedTrainRunAttempt
  104. ] = await self._decorate_train_run_attempts(train_run_attempts)
  105. job_details = jobs[train_run.job_id]
  106. status, status_details = await self._get_run_status(train_run)
  107. decorated_train_run = DecoratedTrainRun.parse_obj(
  108. {
  109. **train_run.dict(),
  110. "attempts": decorated_train_run_attempts,
  111. "job_details": job_details,
  112. "status": status,
  113. "status_detail": status_details,
  114. }
  115. )
  116. decorated_train_runs.append(decorated_train_run)
  117. # Sort train runs in reverse chronological order
  118. decorated_train_runs = sorted(
  119. decorated_train_runs,
  120. key=lambda run: run.start_time_ns,
  121. reverse=True,
  122. )
  123. return decorated_train_runs
  124. async def _get_jobs(self, job_ids: List[str]) -> Dict[str, "JobDetails"]:
  125. return await find_jobs_by_job_ids(
  126. self.gcs_client,
  127. self._job_info_client,
  128. job_ids,
  129. )
  130. async def _decorate_train_run_attempts(
  131. self, train_run_attempts: List["TrainRunAttempt"]
  132. ) -> List["DecoratedTrainRunAttempt"]:
  133. from ray.train.v2._internal.state.schema import DecoratedTrainRunAttempt
  134. decorated_train_run_attempts: List[DecoratedTrainRunAttempt] = []
  135. for train_run_attempt in train_run_attempts:
  136. # TODO: Batch these together across TrainRunAttempts if needed.
  137. decorated_train_workers: List[
  138. DecoratedTrainWorker
  139. ] = await self._decorate_train_workers(train_run_attempt.workers)
  140. decorated_train_run_attempt = DecoratedTrainRunAttempt.parse_obj(
  141. {**train_run_attempt.dict(), "workers": decorated_train_workers}
  142. )
  143. decorated_train_run_attempts.append(decorated_train_run_attempt)
  144. return decorated_train_run_attempts
  145. async def _decorate_train_workers(
  146. self, train_workers: List["TrainWorker"]
  147. ) -> List["DecoratedTrainWorker"]:
  148. from ray.train.v2._internal.state.schema import DecoratedTrainWorker
  149. decorated_train_workers: List[DecoratedTrainWorker] = []
  150. actor_ids = [worker.actor_id for worker in train_workers]
  151. logger.info(f"Getting all actor info from GCS (actor_ids={actor_ids})")
  152. train_run_actors = await self._get_actor_infos(actor_ids)
  153. for train_worker in train_workers:
  154. actor = train_run_actors.get(train_worker.actor_id, None)
  155. # Add hardware metrics to API response
  156. if actor:
  157. gpus = [
  158. gpu
  159. for gpu in actor["gpus"]
  160. if train_worker.pid
  161. in [process["pid"] for process in gpu["processesPids"]]
  162. ]
  163. # Need to convert processesPids into a proper list.
  164. # It's some weird ImmutableList structure
  165. # We also convert the list of processes into a single item since
  166. # an actor is only a single process and cannot match multiple
  167. # processes.
  168. formatted_gpus = [
  169. {
  170. **gpu,
  171. "processInfo": [
  172. process
  173. for process in gpu["processesPids"]
  174. if process["pid"] == train_worker.pid
  175. ][0],
  176. }
  177. for gpu in gpus
  178. ]
  179. decorated_train_worker = DecoratedTrainWorker.parse_obj(
  180. {
  181. **train_worker.dict(),
  182. "status": actor["state"],
  183. "processStats": actor["processStats"],
  184. "gpus": formatted_gpus,
  185. }
  186. )
  187. else:
  188. decorated_train_worker = DecoratedTrainWorker.parse_obj(
  189. train_worker.dict()
  190. )
  191. decorated_train_workers.append(decorated_train_worker)
  192. return decorated_train_workers
  193. async def _get_run_status(
  194. self, train_run: "TrainRun"
  195. ) -> Tuple["RunStatus", Optional[str]]:
  196. from ray.train.v2._internal.state.schema import ActorStatus, RunStatus
  197. # TODO: Move this to the TrainStateActor.
  198. # The train run can be unexpectedly terminated before the final run
  199. # status was updated. This could be due to errors outside of the training
  200. # function (e.g., system failure or user interruption) that crashed the
  201. # train controller.
  202. # We need to detect this case and mark the train run as ABORTED.
  203. actor_infos = await self._get_actor_infos([train_run.controller_actor_id])
  204. controller_actor_info = actor_infos[train_run.controller_actor_id]
  205. controller_actor_status = (
  206. controller_actor_info.get("state") if controller_actor_info else None
  207. )
  208. if (
  209. controller_actor_status == ActorStatus.DEAD
  210. and train_run.status == RunStatus.RUNNING
  211. ):
  212. run_status = RunStatus.ABORTED
  213. status_detail = "Terminated due to system errors or killed by the user."
  214. return (run_status, status_detail)
  215. # Default to original.
  216. return (train_run.status, train_run.status_detail)
  217. # TODO: The next iteration of this should be "/api/train/v1/runs/v3".
  218. # This follows the naming convention of "/api/train/{train_version}/runs/{api_version}".
  219. # This API corresponds to the Train V1 API.
  220. @routes.get("/api/train/v2/runs")
  221. @dashboard_optional_utils.init_ray_and_catch_exceptions()
  222. @DeveloperAPI
  223. async def get_train_runs(self, req: Request) -> Response:
  224. """Get all TrainRunInfos for Ray Train V1."""
  225. try:
  226. from ray.train._internal.state.schema import TrainRunsResponse
  227. except ImportError:
  228. logger.exception(
  229. "Train is not installed. Please run `pip install ray[train]` "
  230. "when setting up Ray on your cluster."
  231. )
  232. return Response(
  233. status=500,
  234. text="Train is not installed. Please run `pip install ray[train]` "
  235. "when setting up Ray on your cluster.",
  236. )
  237. stats_actor = await self.get_train_stats_actor()
  238. if stats_actor is None:
  239. return Response(
  240. status=500,
  241. text=(
  242. "Train state data is not available. Please make sure Ray Train "
  243. "is running and that the Train state actor is enabled by setting "
  244. 'the RAY_TRAIN_ENABLE_STATE_TRACKING environment variable to "1".'
  245. ),
  246. )
  247. else:
  248. try:
  249. train_runs = await stats_actor.get_all_train_runs.remote()
  250. train_runs_with_details = (
  251. await self._add_actor_status_and_update_run_status(train_runs)
  252. )
  253. # Sort train runs in reverse chronological order
  254. train_runs_with_details = sorted(
  255. train_runs_with_details,
  256. key=lambda run: run.start_time_ms,
  257. reverse=True,
  258. )
  259. job_details = await find_jobs_by_job_ids(
  260. self.gcs_client,
  261. self._job_info_client,
  262. [run.job_id for run in train_runs_with_details],
  263. )
  264. for run in train_runs_with_details:
  265. run.job_details = job_details.get(run.job_id)
  266. details = TrainRunsResponse(train_runs=train_runs_with_details)
  267. except ray.exceptions.RayTaskError as e:
  268. # Task failure sometimes are due to GCS
  269. # failure. When GCS failed, we expect a longer time
  270. # to recover.
  271. return Response(
  272. status=503,
  273. text=(
  274. "Failed to get a response from the train stats actor. "
  275. f"The GCS may be down, please retry later: {e}"
  276. ),
  277. )
  278. return Response(
  279. text=details.json(),
  280. content_type="application/json",
  281. )
  282. async def _get_actor_infos(self, actor_ids: List[str]):
  283. if self._node_head_http_session is None:
  284. self._node_head_http_session = get_http_session_to_module(
  285. "NodeHead", self._config.socket_dir, self._config.session_name
  286. )
  287. actor_ids_qs_str = ",".join(actor_ids)
  288. url = f"http://localhost/logical/actors?ids={actor_ids_qs_str}&nocache=1"
  289. async with self._node_head_http_session.get(url) as resp:
  290. resp.raise_for_status()
  291. resp_json = await resp.json()
  292. return resp_json["data"]["actors"]
  293. async def _add_actor_status_and_update_run_status(self, train_runs):
  294. from ray.train._internal.state.schema import (
  295. ActorStatusEnum,
  296. RunStatusEnum,
  297. TrainRunInfoWithDetails,
  298. TrainWorkerInfoWithDetails,
  299. )
  300. train_runs_with_details: List[TrainRunInfoWithDetails] = []
  301. for train_run in train_runs.values():
  302. worker_infos_with_details: List[TrainWorkerInfoWithDetails] = []
  303. actor_ids = [worker.actor_id for worker in train_run.workers]
  304. logger.info(f"Getting all actor info from GCS (actor_ids={actor_ids})")
  305. train_run_actors = await self._get_actor_infos(actor_ids)
  306. for worker_info in train_run.workers:
  307. actor = train_run_actors.get(worker_info.actor_id, None)
  308. # Add hardware metrics to API response
  309. if actor:
  310. gpus = [
  311. gpu
  312. for gpu in actor["gpus"]
  313. if worker_info.pid
  314. in [process["pid"] for process in gpu["processesPids"]]
  315. ]
  316. # Need to convert processesPids into a proper list.
  317. # It's some weird ImmutableList structureo
  318. # We also convert the list of processes into a single item since
  319. # an actor is only a single process and cannot match multiple
  320. # processes.
  321. formatted_gpus = [
  322. {
  323. **gpu,
  324. "processInfo": [
  325. process
  326. for process in gpu["processesPids"]
  327. if process["pid"] == worker_info.pid
  328. ][0],
  329. }
  330. for gpu in gpus
  331. ]
  332. worker_info_with_details = TrainWorkerInfoWithDetails.parse_obj(
  333. {
  334. **worker_info.dict(),
  335. "status": actor["state"],
  336. "processStats": actor["processStats"],
  337. "gpus": formatted_gpus,
  338. }
  339. )
  340. else:
  341. worker_info_with_details = TrainWorkerInfoWithDetails.parse_obj(
  342. worker_info.dict()
  343. )
  344. worker_infos_with_details.append(worker_info_with_details)
  345. train_run_with_details = TrainRunInfoWithDetails.parse_obj(
  346. {**train_run.dict(), "workers": worker_infos_with_details}
  347. )
  348. # The train run can be unexpectedly terminated before the final run
  349. # status was updated. This could be due to errors outside of the training
  350. # function (e.g., system failure or user interruption) that crashed the
  351. # train controller.
  352. # We need to detect this case and mark the train run as ABORTED.
  353. actor = train_run_actors.get(train_run.controller_actor_id)
  354. controller_actor_status = actor.get("state") if actor else None
  355. if (
  356. controller_actor_status == ActorStatusEnum.DEAD
  357. and train_run.run_status == RunStatusEnum.RUNNING
  358. ):
  359. train_run_with_details.run_status = RunStatusEnum.ABORTED
  360. train_run_with_details.status_detail = (
  361. "Terminated due to system errors or killed by the user."
  362. )
  363. train_runs_with_details.append(train_run_with_details)
  364. return train_runs_with_details
  365. async def run(self):
  366. await super().run()
  367. if not self._job_info_client:
  368. self._job_info_client = JobInfoStorageClient(self.gcs_client)
  369. gcs_channel = self.aiogrpc_gcs_channel
  370. self._gcs_actor_info_stub = gcs_service_pb2_grpc.ActorInfoGcsServiceStub(
  371. gcs_channel
  372. )
  373. async def get_train_stats_actor(self):
  374. """
  375. Gets the train stats actor and caches it as an instance variable.
  376. """
  377. try:
  378. from ray.train._internal.state.state_actor import get_state_actor
  379. if self._train_stats_actor is None:
  380. self._train_stats_actor = get_state_actor()
  381. return self._train_stats_actor
  382. except ImportError:
  383. logger.exception(
  384. "Train is not installed. Please run `pip install ray[train]` "
  385. "when setting up Ray on your cluster."
  386. )
  387. return None
  388. async def get_train_v2_state_actor(self):
  389. """
  390. Gets the Train state actor and caches it as an instance variable.
  391. """
  392. try:
  393. from ray.train.v2._internal.state.state_actor import get_state_actor
  394. if self._train_v2_state_actor is None:
  395. self._train_v2_state_actor = get_state_actor()
  396. return self._train_v2_state_actor
  397. except ImportError:
  398. logger.exception(
  399. "Train is not installed. Please run `pip install ray[train]` "
  400. "when setting up Ray on your cluster."
  401. )
  402. return None