cli.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561
  1. import json
  2. import os
  3. import pprint
  4. import sys
  5. import time
  6. from subprocess import list2cmdline
  7. from typing import Any, Dict, Optional, Tuple, Union
  8. import click
  9. import ray._private.ray_constants as ray_constants
  10. from ray._common.utils import (
  11. get_or_create_event_loop,
  12. load_class,
  13. )
  14. from ray._private.utils import (
  15. parse_metadata_json,
  16. parse_resources_json,
  17. )
  18. from ray.autoscaler._private.cli_logger import add_click_logging_options, cf, cli_logger
  19. from ray.dashboard.modules.dashboard_sdk import parse_runtime_env_args
  20. from ray.dashboard.modules.job.cli_utils import add_common_job_options
  21. from ray.dashboard.modules.job.utils import redact_url_password
  22. from ray.job_submission import JobStatus, JobSubmissionClient
  23. from ray.util.annotations import PublicAPI
  24. def _get_sdk_client(
  25. address: Optional[str],
  26. create_cluster_if_needed: bool = False,
  27. headers: Optional[str] = None,
  28. verify: Union[bool, str] = True,
  29. ) -> JobSubmissionClient:
  30. client = JobSubmissionClient(
  31. address,
  32. create_cluster_if_needed,
  33. headers=_handle_headers(headers),
  34. verify=verify,
  35. )
  36. client_address = client.get_address()
  37. cli_logger.labeled_value(
  38. "Job submission server address", redact_url_password(client_address)
  39. )
  40. return client
  41. def _handle_headers(headers: Optional[str]) -> Optional[Dict[str, Any]]:
  42. if headers is None and "RAY_JOB_HEADERS" in os.environ:
  43. headers = os.environ["RAY_JOB_HEADERS"]
  44. if headers is not None:
  45. try:
  46. return json.loads(headers)
  47. except Exception as exc:
  48. raise ValueError(
  49. """Failed to parse headers into JSON.
  50. Expected format: {{"KEY": "VALUE"}}, got {}, {}""".format(
  51. headers, exc
  52. )
  53. )
  54. return None
  55. def _log_big_success_msg(success_msg):
  56. cli_logger.newline()
  57. cli_logger.success("-" * len(success_msg))
  58. cli_logger.success(success_msg)
  59. cli_logger.success("-" * len(success_msg))
  60. cli_logger.newline()
  61. def _log_big_error_msg(success_msg):
  62. cli_logger.newline()
  63. cli_logger.error("-" * len(success_msg))
  64. cli_logger.error(success_msg)
  65. cli_logger.error("-" * len(success_msg))
  66. cli_logger.newline()
  67. def _log_job_status(client: JobSubmissionClient, job_id: str) -> JobStatus:
  68. info = client.get_job_info(job_id)
  69. if info.status == JobStatus.SUCCEEDED:
  70. _log_big_success_msg(f"Job '{job_id}' succeeded")
  71. elif info.status == JobStatus.STOPPED:
  72. cli_logger.warning(f"Job '{job_id}' was stopped")
  73. elif info.status == JobStatus.FAILED:
  74. _log_big_error_msg(f"Job '{job_id}' failed")
  75. if info.message is not None:
  76. cli_logger.print(f"Status message: {info.message}", no_format=True)
  77. else:
  78. # Catch-all.
  79. cli_logger.print(f"Status for job '{job_id}': {info.status}")
  80. if info.message is not None:
  81. cli_logger.print(f"Status message: {info.message}", no_format=True)
  82. return info.status
  83. async def _tail_logs(client: JobSubmissionClient, job_id: str) -> JobStatus:
  84. async for lines in client.tail_job_logs(job_id):
  85. print(lines, end="")
  86. return _log_job_status(client, job_id)
  87. @click.group("job")
  88. def job_cli_group():
  89. """Submit, stop, delete, or list Ray jobs."""
  90. pass
  91. @job_cli_group.command()
  92. @click.option(
  93. "--address",
  94. type=str,
  95. default=None,
  96. required=False,
  97. help=(
  98. "Address of the Ray cluster to connect to. Can also be specified "
  99. "using the RAY_API_SERVER_ADDRESS environment variable (falls back to RAY_ADDRESS)."
  100. ),
  101. )
  102. @click.option(
  103. "--job-id",
  104. type=str,
  105. default=None,
  106. required=False,
  107. help=("DEPRECATED: Use `--submission-id` instead."),
  108. )
  109. @click.option(
  110. "--submission-id",
  111. type=str,
  112. default=None,
  113. required=False,
  114. help=(
  115. "Submission ID to specify for the job. "
  116. "If not provided, one will be generated."
  117. ),
  118. )
  119. @click.option(
  120. "--runtime-env",
  121. type=str,
  122. default=None,
  123. required=False,
  124. help="Path to a local YAML file containing a runtime_env definition.",
  125. )
  126. @click.option(
  127. "--runtime-env-json",
  128. type=str,
  129. default=None,
  130. required=False,
  131. help="JSON-serialized runtime_env dictionary.",
  132. )
  133. @click.option(
  134. "--working-dir",
  135. type=str,
  136. default=None,
  137. required=False,
  138. help=(
  139. "Directory containing files that your job will run in. Can be a "
  140. "local directory or a remote URI to a .zip file (S3, GS, HTTP). "
  141. "If specified, this overrides the option in `--runtime-env`."
  142. ),
  143. )
  144. @click.option(
  145. "--metadata-json",
  146. type=str,
  147. default=None,
  148. required=False,
  149. help="JSON-serialized dictionary of metadata to attach to the job.",
  150. )
  151. @click.option(
  152. "--entrypoint-num-cpus",
  153. required=False,
  154. type=float,
  155. help="the quantity of CPU cores to reserve for the entrypoint command, "
  156. "separately from any tasks or actors that are launched by it",
  157. )
  158. @click.option(
  159. "--entrypoint-num-gpus",
  160. required=False,
  161. type=float,
  162. help="the quantity of GPUs to reserve for the entrypoint command, "
  163. "separately from any tasks or actors that are launched by it",
  164. )
  165. @click.option(
  166. "--entrypoint-memory",
  167. required=False,
  168. type=int,
  169. help="the amount of memory to reserve "
  170. "for the entrypoint command, separately from any tasks or actors that are "
  171. "launched by it",
  172. )
  173. @click.option(
  174. "--entrypoint-resources",
  175. required=False,
  176. type=str,
  177. help="a JSON-serialized dictionary mapping resource name to resource quantity "
  178. "describing resources to reserve for the entrypoint command, "
  179. "separately from any tasks or actors that are launched by it",
  180. )
  181. @click.option(
  182. "--entrypoint-label-selector",
  183. required=False,
  184. type=str,
  185. help="a JSON-serialized dictionary mapping label keys to selector strings "
  186. "describing placement constraints for the entrypoint command",
  187. )
  188. @click.option(
  189. "--no-wait",
  190. is_flag=True,
  191. type=bool,
  192. default=False,
  193. help="If set, will not stream logs and wait for the job to exit.",
  194. )
  195. @add_common_job_options
  196. @add_click_logging_options
  197. @click.argument("entrypoint", nargs=-1, required=True, type=click.UNPROCESSED)
  198. @PublicAPI
  199. def submit(
  200. address: Optional[str],
  201. job_id: Optional[str],
  202. submission_id: Optional[str],
  203. runtime_env: Optional[str],
  204. runtime_env_json: Optional[str],
  205. metadata_json: Optional[str],
  206. working_dir: Optional[str],
  207. entrypoint: Tuple[str],
  208. entrypoint_num_cpus: Optional[Union[int, float]],
  209. entrypoint_num_gpus: Optional[Union[int, float]],
  210. entrypoint_memory: Optional[int],
  211. entrypoint_resources: Optional[str],
  212. entrypoint_label_selector: Optional[str],
  213. no_wait: bool,
  214. verify: Union[bool, str],
  215. headers: Optional[str],
  216. ):
  217. """Submits a job to be run on the cluster.
  218. By default (if --no-wait is not set), streams logs to stdout until the job finishes.
  219. If the job succeeded, exits with 0. If it failed, exits with 1.
  220. Example:
  221. `ray job submit -- python my_script.py --arg=val`
  222. Args:
  223. address: Job submission server address.
  224. job_id: DEPRECATED. Use submission_id instead.
  225. submission_id: Submission ID for the job.
  226. runtime_env: Path to a runtime_env YAML file.
  227. runtime_env_json: JSON-serialized runtime_env dictionary.
  228. metadata_json: JSON-serialized metadata dictionary.
  229. working_dir: Working directory for the job.
  230. entrypoint: Entrypoint command.
  231. entrypoint_num_cpus: CPU cores to reserve.
  232. entrypoint_num_gpus: GPUs to reserve.
  233. entrypoint_memory: Memory to reserve.
  234. entrypoint_resources: JSON-serialized custom resources dict.
  235. entrypoint_label_selector: JSON-serialized label selector dict.
  236. no_wait: Do not wait for job completion.
  237. verify: TLS verification flag or path.
  238. headers: JSON-serialized headers.
  239. """
  240. if job_id:
  241. cli_logger.warning(
  242. "--job-id option is deprecated. Please use --submission-id instead."
  243. )
  244. if entrypoint_resources is not None:
  245. entrypoint_resources = parse_resources_json(
  246. entrypoint_resources, cli_logger, cf, command_arg="entrypoint-resources"
  247. )
  248. if entrypoint_label_selector is not None:
  249. entrypoint_label_selector = parse_resources_json(
  250. entrypoint_label_selector,
  251. cli_logger,
  252. cf,
  253. command_arg="entrypoint-label-selector",
  254. )
  255. if metadata_json is not None:
  256. metadata_json = parse_metadata_json(
  257. metadata_json, cli_logger, cf, command_arg="metadata-json"
  258. )
  259. submission_id = submission_id or job_id
  260. if ray_constants.RAY_JOB_SUBMIT_HOOK in os.environ:
  261. # Submit all args as **kwargs per the JOB_SUBMIT_HOOK contract.
  262. load_class(os.environ[ray_constants.RAY_JOB_SUBMIT_HOOK])(
  263. address=address,
  264. job_id=submission_id,
  265. submission_id=submission_id,
  266. runtime_env=runtime_env,
  267. runtime_env_json=runtime_env_json,
  268. metadata_json=metadata_json,
  269. working_dir=working_dir,
  270. entrypoint=entrypoint,
  271. entrypoint_num_cpus=entrypoint_num_cpus,
  272. entrypoint_num_gpus=entrypoint_num_gpus,
  273. entrypoint_memory=entrypoint_memory,
  274. entrypoint_resources=entrypoint_resources,
  275. entrypoint_label_selector=entrypoint_label_selector,
  276. no_wait=no_wait,
  277. )
  278. client = _get_sdk_client(
  279. address, create_cluster_if_needed=True, headers=headers, verify=verify
  280. )
  281. final_runtime_env = parse_runtime_env_args(
  282. runtime_env=runtime_env,
  283. runtime_env_json=runtime_env_json,
  284. working_dir=working_dir,
  285. )
  286. job_id = client.submit_job(
  287. entrypoint=list2cmdline(entrypoint),
  288. submission_id=submission_id,
  289. runtime_env=final_runtime_env,
  290. metadata=metadata_json,
  291. entrypoint_num_cpus=entrypoint_num_cpus,
  292. entrypoint_num_gpus=entrypoint_num_gpus,
  293. entrypoint_memory=entrypoint_memory,
  294. entrypoint_resources=entrypoint_resources,
  295. entrypoint_label_selector=entrypoint_label_selector,
  296. )
  297. _log_big_success_msg(f"Job '{job_id}' submitted successfully")
  298. with cli_logger.group("Next steps"):
  299. cli_logger.print("Query the logs of the job:")
  300. with cli_logger.indented():
  301. cli_logger.print(cf.bold(f"ray job logs {job_id}"))
  302. cli_logger.print("Query the status of the job:")
  303. with cli_logger.indented():
  304. cli_logger.print(cf.bold(f"ray job status {job_id}"))
  305. cli_logger.print("Request the job to be stopped:")
  306. with cli_logger.indented():
  307. cli_logger.print(cf.bold(f"ray job stop {job_id}"))
  308. cli_logger.newline()
  309. # Flush stdout to ensure the Ray job ID is output immediately
  310. # for the kubectl plugin, ref PR #52780, Issue kuberay/#3508.
  311. cli_logger.flush()
  312. sdk_version = client.get_version()
  313. # sdk version 0 does not have log streaming
  314. if not no_wait:
  315. if int(sdk_version) > 0:
  316. cli_logger.print(
  317. "Tailing logs until the job exits (disable with --no-wait):"
  318. )
  319. job_status = get_or_create_event_loop().run_until_complete(
  320. _tail_logs(client, job_id)
  321. )
  322. if job_status == JobStatus.FAILED:
  323. sys.exit(1)
  324. else:
  325. cli_logger.warning(
  326. "Tailing logs is not enabled for job sdk client version "
  327. f"{sdk_version}. Please upgrade Ray to the latest version "
  328. "for this feature."
  329. )
  330. @job_cli_group.command()
  331. @click.option(
  332. "--address",
  333. type=str,
  334. default=None,
  335. required=False,
  336. help=(
  337. "Address of the Ray cluster to connect to. Can also be specified "
  338. "using the RAY_API_SERVER_ADDRESS environment variable (falls back to RAY_ADDRESS)."
  339. ),
  340. )
  341. @click.argument("job-id", type=str)
  342. @add_common_job_options
  343. @add_click_logging_options
  344. @PublicAPI(stability="stable")
  345. def status(
  346. address: Optional[str],
  347. job_id: str,
  348. headers: Optional[str],
  349. verify: Union[bool, str],
  350. ):
  351. """Queries for the current status of a job.
  352. Example:
  353. `ray job status <my_job_id>`
  354. """
  355. client = _get_sdk_client(address, headers=headers, verify=verify)
  356. _log_job_status(client, job_id)
  357. @job_cli_group.command()
  358. @click.option(
  359. "--address",
  360. type=str,
  361. default=None,
  362. required=False,
  363. help=(
  364. "Address of the Ray cluster to connect to. Can also be specified "
  365. "using the RAY_API_SERVER_ADDRESS environment variable (falls back to RAY_ADDRESS)."
  366. ),
  367. )
  368. @click.option(
  369. "--no-wait",
  370. is_flag=True,
  371. type=bool,
  372. default=False,
  373. help="If set, will not wait for the job to exit.",
  374. )
  375. @click.argument("job-id", type=str)
  376. @add_common_job_options
  377. @add_click_logging_options
  378. @PublicAPI(stability="stable")
  379. def stop(
  380. address: Optional[str],
  381. no_wait: bool,
  382. job_id: str,
  383. headers: Optional[str],
  384. verify: Union[bool, str],
  385. ):
  386. """Attempts to stop a job.
  387. Example:
  388. `ray job stop <my_job_id>`
  389. """
  390. client = _get_sdk_client(address, headers=headers, verify=verify)
  391. cli_logger.print(f"Attempting to stop job '{job_id}'")
  392. client.stop_job(job_id)
  393. if no_wait:
  394. return
  395. else:
  396. cli_logger.print(
  397. f"Waiting for job '{job_id}' to exit " f"(disable with --no-wait):"
  398. )
  399. while True:
  400. status = client.get_job_status(job_id)
  401. if status in {JobStatus.STOPPED, JobStatus.SUCCEEDED, JobStatus.FAILED}:
  402. _log_job_status(client, job_id)
  403. break
  404. else:
  405. cli_logger.print(f"Job has not exited yet. Status: {status}")
  406. time.sleep(1)
  407. @job_cli_group.command()
  408. @click.option(
  409. "--address",
  410. type=str,
  411. default=None,
  412. required=False,
  413. help=(
  414. "Address of the Ray cluster to connect to. Can also be specified "
  415. "using the RAY_API_SERVER_ADDRESS environment variable (falls back to RAY_ADDRESS)."
  416. ),
  417. )
  418. @click.argument("job-id", type=str)
  419. @add_common_job_options
  420. @add_click_logging_options
  421. @PublicAPI(stability="stable")
  422. def delete(
  423. address: Optional[str],
  424. job_id: str,
  425. headers: Optional[str],
  426. verify: Union[bool, str],
  427. ):
  428. """Deletes a stopped job and its associated data from memory.
  429. Only supported for jobs that are already in a terminal state.
  430. Fails with exit code 1 if the job is not already stopped.
  431. Does not delete job logs from disk.
  432. Submitting a job with the same submission ID as a previously
  433. deleted job is not supported and may lead to unexpected behavior.
  434. Example:
  435. ray job delete <my_job_id>
  436. """
  437. client = _get_sdk_client(address, headers=headers, verify=verify)
  438. client.delete_job(job_id)
  439. cli_logger.print(f"Job '{job_id}' deleted successfully")
  440. @job_cli_group.command()
  441. @click.option(
  442. "--address",
  443. type=str,
  444. default=None,
  445. required=False,
  446. help=(
  447. "Address of the Ray cluster to connect to. Can also be specified "
  448. "using the RAY_API_SERVER_ADDRESS environment variable (falls back to RAY_ADDRESS)."
  449. ),
  450. )
  451. @click.argument("job-id", type=str)
  452. @click.option(
  453. "-f",
  454. "--follow",
  455. is_flag=True,
  456. type=bool,
  457. default=False,
  458. help="If set, follow the logs (like `tail -f`).",
  459. )
  460. @add_common_job_options
  461. @add_click_logging_options
  462. @PublicAPI(stability="stable")
  463. def logs(
  464. address: Optional[str],
  465. job_id: str,
  466. follow: bool,
  467. headers: Optional[str],
  468. verify: Union[bool, str],
  469. ):
  470. """Gets the logs of a job.
  471. Example:
  472. `ray job logs <my_job_id>`
  473. """
  474. client = _get_sdk_client(address, headers=headers, verify=verify)
  475. sdk_version = client.get_version()
  476. # sdk version 0 did not have log streaming
  477. if follow:
  478. if int(sdk_version) > 0:
  479. get_or_create_event_loop().run_until_complete(_tail_logs(client, job_id))
  480. else:
  481. cli_logger.warning(
  482. "Tailing logs is not enabled for the Jobs SDK client version "
  483. f"{sdk_version}. Please upgrade Ray to latest version "
  484. "for this feature."
  485. )
  486. else:
  487. # Set no_format to True because the logs may have unescaped "{" and "}"
  488. # and the CLILogger calls str.format().
  489. cli_logger.print(client.get_job_logs(job_id), end="", no_format=True)
  490. @job_cli_group.command()
  491. @click.option(
  492. "--address",
  493. type=str,
  494. default=None,
  495. required=False,
  496. help=(
  497. "Address of the Ray cluster to connect to. Can also be specified "
  498. "using the RAY_API_SERVER_ADDRESS environment variable (falls back to RAY_ADDRESS)."
  499. ),
  500. )
  501. @add_common_job_options
  502. @add_click_logging_options
  503. @PublicAPI(stability="stable")
  504. def list(address: Optional[str], headers: Optional[str], verify: Union[bool, str]):
  505. """Lists all running jobs and their information.
  506. Example:
  507. `ray job list`
  508. """
  509. client = _get_sdk_client(address, headers=headers, verify=verify)
  510. # Set no_format to True because the logs may have unescaped "{" and "}"
  511. # and the CLILogger calls str.format().
  512. cli_logger.print(pprint.pformat(client.list_jobs()), no_format=True)