sdk.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532
  1. import dataclasses
  2. import logging
  3. from typing import Any, AsyncIterator, Dict, List, Optional, Union
  4. import packaging.version
  5. import ray
  6. from ray.dashboard.modules.dashboard_sdk import SubmissionClient
  7. from ray.dashboard.modules.job.common import (
  8. JobDeleteResponse,
  9. JobLogsResponse,
  10. JobStatus,
  11. JobStopResponse,
  12. JobSubmitRequest,
  13. JobSubmitResponse,
  14. )
  15. from ray.dashboard.modules.job.pydantic_models import JobDetails
  16. from ray.dashboard.modules.job.utils import strip_keys_with_value_none
  17. from ray.dashboard.utils import get_address_for_submission_client
  18. from ray.runtime_env import RuntimeEnv
  19. from ray.runtime_env.runtime_env import _validate_no_local_paths
  20. from ray.util.annotations import PublicAPI
  21. try:
  22. import aiohttp
  23. import requests
  24. except ImportError:
  25. aiohttp = None
  26. requests = None
  27. logger = logging.getLogger(__name__)
  28. logger.setLevel(logging.INFO)
  29. class JobSubmissionClient(SubmissionClient):
  30. """A local client for submitting and interacting with jobs on a remote cluster.
  31. Submits requests over HTTP to the job server on the cluster using the REST API.
  32. Args:
  33. address: Either (1) the address of the Ray cluster, or (2) the HTTP address
  34. of the dashboard server on the head node, e.g. "http://<head-node-ip>:8265".
  35. In case (1) it must be specified as an address that can be passed to
  36. ray.init(), e.g. a Ray Client address (ray://<head_node_host>:10001),
  37. or "auto", or "localhost:<port>". If unspecified, will try to connect to
  38. a running local Ray cluster. This argument is always overridden by the
  39. RAY_API_SERVER_ADDRESS or RAY_ADDRESS environment variable.
  40. create_cluster_if_needed: Indicates whether the cluster at the specified
  41. address needs to already be running. Ray doesn't start a cluster
  42. before interacting with jobs, but third-party job managers may do so.
  43. cookies: Cookies to use when sending requests to the HTTP job server.
  44. metadata: Arbitrary metadata to store along with all jobs. New metadata
  45. specified per job will be merged with the global metadata provided here
  46. via a simple dict update.
  47. headers: Headers to use when sending requests to the HTTP job server, used
  48. for cases like authentication to a remote cluster.
  49. verify: Boolean indication to verify the server's TLS certificate or a path to
  50. a file or directory of trusted certificates. Default: True.
  51. """
  52. def __init__(
  53. self,
  54. address: Optional[str] = None,
  55. create_cluster_if_needed: bool = False,
  56. cookies: Optional[Dict[str, Any]] = None,
  57. metadata: Optional[Dict[str, Any]] = None,
  58. headers: Optional[Dict[str, Any]] = None,
  59. verify: Optional[Union[str, bool]] = True,
  60. ):
  61. self._client_ray_version = ray.__version__
  62. """Initialize a JobSubmissionClient and check the connection to the cluster."""
  63. if requests is None:
  64. raise RuntimeError(
  65. "The Ray jobs CLI & SDK require the ray[default] "
  66. "installation: `pip install 'ray[default]'`"
  67. )
  68. # Check types of arguments
  69. if address is not None and not isinstance(address, str):
  70. raise TypeError(f"address must be a string, got {type(address)}")
  71. if not isinstance(create_cluster_if_needed, bool):
  72. raise TypeError(
  73. f"create_cluster_if_needed must be a bool, got"
  74. f" {type(create_cluster_if_needed)}"
  75. )
  76. if cookies is not None and not isinstance(cookies, dict):
  77. raise TypeError(f"cookies must be a dict, got {type(cookies)}")
  78. if metadata is not None and not isinstance(metadata, dict):
  79. raise TypeError(f"metadata must be a dict, got {type(metadata)}")
  80. if headers is not None and not isinstance(headers, dict):
  81. raise TypeError(f"headers must be a dict, got {type(headers)}")
  82. if not (isinstance(verify, str) or isinstance(verify, bool)):
  83. raise TypeError(f"verify must be a str or bool, got {type(verify)}")
  84. api_server_url = get_address_for_submission_client(address)
  85. super().__init__(
  86. address=api_server_url,
  87. create_cluster_if_needed=create_cluster_if_needed,
  88. cookies=cookies,
  89. metadata=metadata,
  90. headers=headers,
  91. verify=verify,
  92. )
  93. self._check_connection_and_version(
  94. min_version="1.9",
  95. version_error_message="Jobs API is not supported on the Ray "
  96. "cluster. Please ensure the cluster is "
  97. "running Ray 1.9 or higher.",
  98. )
  99. # In ray>=2.0, the client sends the new kwarg `submission_id` to the server
  100. # upon every job submission, which causes servers with ray<2.0 to error.
  101. if packaging.version.parse(self._client_ray_version) > packaging.version.parse(
  102. "2.0"
  103. ):
  104. self._check_connection_and_version(
  105. min_version="2.0",
  106. version_error_message=f"Client Ray version {self._client_ray_version} "
  107. "is not compatible with the Ray cluster. Please ensure the cluster is "
  108. "running Ray 2.0 or higher or downgrade the client Ray version.",
  109. )
  110. @PublicAPI(stability="stable")
  111. def submit_job(
  112. self,
  113. *,
  114. entrypoint: str,
  115. job_id: Optional[str] = None,
  116. runtime_env: Optional[Dict[str, Any]] = None,
  117. metadata: Optional[Dict[str, str]] = None,
  118. submission_id: Optional[str] = None,
  119. entrypoint_num_cpus: Optional[Union[int, float]] = None,
  120. entrypoint_num_gpus: Optional[Union[int, float]] = None,
  121. entrypoint_memory: Optional[int] = None,
  122. entrypoint_resources: Optional[Dict[str, float]] = None,
  123. entrypoint_label_selector: Optional[Dict[str, str]] = None,
  124. ) -> str:
  125. """Submit and execute a job asynchronously.
  126. When a job is submitted, it runs once to completion or failure. Retries or
  127. different runs with different parameters should be handled by the
  128. submitter. Jobs are bound to the lifetime of a Ray cluster, so if the
  129. cluster goes down, all running jobs on that cluster will be terminated.
  130. Example:
  131. >>> from ray.job_submission import JobSubmissionClient
  132. >>> client = JobSubmissionClient("http://127.0.0.1:8265") # doctest: +SKIP
  133. >>> client.submit_job( # doctest: +SKIP
  134. ... entrypoint="python script.py",
  135. ... runtime_env={
  136. ... "working_dir": "./",
  137. ... "pip": ["requests==2.26.0"]
  138. ... }
  139. ... ) # doctest: +SKIP
  140. 'raysubmit_4LamXRuQpYdSMg7J'
  141. Args:
  142. entrypoint: The shell command to run for this job.
  143. job_id: DEPRECATED. This has been renamed to submission_id.
  144. runtime_env: The runtime environment to install and run this job in.
  145. metadata: Arbitrary data to store along with this job.
  146. submission_id: A unique ID for this job.
  147. entrypoint_num_cpus: The quantity of CPU cores to reserve for the execution
  148. of the entrypoint command, separately from any tasks or actors launched
  149. by it. Defaults to 0.
  150. entrypoint_num_gpus: The quantity of GPUs to reserve for the execution
  151. of the entrypoint command, separately from any tasks or actors launched
  152. by it. Defaults to 0.
  153. entrypoint_memory: The quantity of memory to reserve for the
  154. execution of the entrypoint command, separately from any tasks or
  155. actors launched by it. Defaults to 0.
  156. entrypoint_resources: The quantity of custom resources to reserve for the
  157. execution of the entrypoint command, separately from any tasks or
  158. actors launched by it.
  159. entrypoint_label_selector: Label selector for the entrypoint command.
  160. Returns:
  161. The submission ID of the submitted job. If not specified,
  162. this is a randomly generated unique ID.
  163. Raises:
  164. RuntimeError: If the request to the job server fails, or if the specified
  165. submission_id has already been used by a job on this cluster.
  166. """
  167. if job_id:
  168. logger.warning(
  169. "job_id kwarg is deprecated. Please use submission_id instead."
  170. )
  171. if (
  172. entrypoint_num_cpus
  173. or entrypoint_num_gpus
  174. or entrypoint_resources
  175. or entrypoint_label_selector
  176. ):
  177. self._check_connection_and_version(
  178. min_version="2.2",
  179. version_error_message="`entrypoint_num_cpus`, `entrypoint_num_gpus`, "
  180. "`entrypoint_resources`, and `entrypoint_label_selector` kwargs "
  181. "are not supported on the Ray cluster. Please ensure the cluster is "
  182. "running Ray 2.2 or higher.",
  183. )
  184. if entrypoint_memory:
  185. self._check_connection_and_version(
  186. min_version="2.8",
  187. version_error_message="`entrypoint_memory` kwarg "
  188. "is not supported on the Ray cluster. Please ensure the cluster is "
  189. "running Ray 2.8 or higher.",
  190. )
  191. runtime_env = runtime_env or {}
  192. metadata = metadata or {}
  193. metadata.update(self._default_metadata)
  194. self._upload_working_dir_if_needed(runtime_env)
  195. self._upload_py_modules_if_needed(runtime_env)
  196. # Verify worker_process_setup_hook type.
  197. setup_hook = runtime_env.get("worker_process_setup_hook")
  198. if setup_hook and not isinstance(setup_hook, str):
  199. raise ValueError(
  200. f"Invalid type {type(setup_hook)} for `worker_process_setup_hook`. "
  201. "When a job submission API is used, `worker_process_setup_hook` "
  202. "only allows a string type (module name). "
  203. "Specify `worker_process_setup_hook` via "
  204. "ray.init within a driver to use a `Callable` type. "
  205. )
  206. # Run the RuntimeEnv constructor to parse local pip/conda requirements files.
  207. runtime_env = RuntimeEnv(**runtime_env)
  208. _validate_no_local_paths(runtime_env)
  209. runtime_env = runtime_env.to_dict()
  210. submission_id = submission_id or job_id
  211. req = JobSubmitRequest(
  212. entrypoint=entrypoint,
  213. submission_id=submission_id,
  214. runtime_env=runtime_env,
  215. metadata=metadata,
  216. entrypoint_num_cpus=entrypoint_num_cpus,
  217. entrypoint_num_gpus=entrypoint_num_gpus,
  218. entrypoint_memory=entrypoint_memory,
  219. entrypoint_resources=entrypoint_resources,
  220. entrypoint_label_selector=entrypoint_label_selector,
  221. )
  222. # Remove keys with value None so that new clients with new optional fields
  223. # are still compatible with older servers. This is also done on the server,
  224. # but we do it here as well to be extra defensive.
  225. json_data = strip_keys_with_value_none(dataclasses.asdict(req))
  226. logger.debug(f"Submitting job with submission_id={submission_id}.")
  227. r = self._do_request("POST", "/api/jobs/", json_data=json_data)
  228. if r.status_code == 200:
  229. return JobSubmitResponse(**r.json()).submission_id
  230. else:
  231. self._raise_error(r)
  232. @PublicAPI(stability="stable")
  233. def stop_job(
  234. self,
  235. job_id: str,
  236. ) -> bool:
  237. """Request a job to exit asynchronously.
  238. Attempts to terminate process first, then kills process after timeout.
  239. Example:
  240. >>> from ray.job_submission import JobSubmissionClient
  241. >>> client = JobSubmissionClient("http://127.0.0.1:8265") # doctest: +SKIP
  242. >>> sub_id = client.submit_job(entrypoint="sleep 10") # doctest: +SKIP
  243. >>> client.stop_job(sub_id) # doctest: +SKIP
  244. True
  245. Args:
  246. job_id: The job ID or submission ID for the job to be stopped.
  247. Returns:
  248. True if the job was running, otherwise False.
  249. Raises:
  250. RuntimeError: If the job does not exist or if the request to the
  251. job server fails.
  252. """
  253. logger.debug(f"Stopping job with job_id={job_id}.")
  254. r = self._do_request("POST", f"/api/jobs/{job_id}/stop")
  255. if r.status_code == 200:
  256. return JobStopResponse(**r.json()).stopped
  257. else:
  258. self._raise_error(r)
  259. @PublicAPI(stability="stable")
  260. def delete_job(
  261. self,
  262. job_id: str,
  263. ) -> bool:
  264. """Delete a job in a terminal state and all of its associated data.
  265. If the job is not already in a terminal state, raises an error.
  266. This does not delete the job logs from disk.
  267. Submitting a job with the same submission ID as a previously
  268. deleted job is not supported and may lead to unexpected behavior.
  269. Example:
  270. >>> from ray.job_submission import JobSubmissionClient
  271. >>> client = JobSubmissionClient() # doctest: +SKIP
  272. >>> job_id = client.submit_job(entrypoint="echo hello") # doctest: +SKIP
  273. >>> client.delete_job(job_id) # doctest: +SKIP
  274. True
  275. Args:
  276. job_id: submission ID for the job to be deleted.
  277. Returns:
  278. True if the job was deleted, otherwise False.
  279. Raises:
  280. RuntimeError: If the job does not exist, if the request to the
  281. job server fails, or if the job is not in a terminal state.
  282. """
  283. logger.debug(f"Deleting job with job_id={job_id}.")
  284. r = self._do_request("DELETE", f"/api/jobs/{job_id}")
  285. if r.status_code == 200:
  286. return JobDeleteResponse(**r.json()).deleted
  287. else:
  288. self._raise_error(r)
  289. @PublicAPI(stability="stable")
  290. def get_job_info(
  291. self,
  292. job_id: str,
  293. ) -> JobDetails:
  294. """Get the latest status and other information associated with a job.
  295. Example:
  296. >>> from ray.job_submission import JobSubmissionClient
  297. >>> client = JobSubmissionClient("http://127.0.0.1:8265") # doctest: +SKIP
  298. >>> submission_id = client.submit_job(entrypoint="sleep 1") # doctest: +SKIP
  299. >>> client.get_job_info(submission_id) # doctest: +SKIP
  300. JobDetails(status='SUCCEEDED',
  301. job_id='03000000', type='submission',
  302. submission_id='raysubmit_4LamXRuQpYdSMg7J',
  303. message='Job finished successfully.', error_type=None,
  304. start_time=1647388711, end_time=1647388712, metadata={}, runtime_env={})
  305. Args:
  306. job_id: The job ID or submission ID of the job whose information
  307. is being requested.
  308. Returns:
  309. The JobDetails for the job.
  310. Raises:
  311. RuntimeError: If the job does not exist or if the request to the
  312. job server fails.
  313. """
  314. r = self._do_request("GET", f"/api/jobs/{job_id}")
  315. if r.status_code == 200:
  316. if JobDetails is None:
  317. raise RuntimeError(
  318. "The Ray jobs CLI & SDK require the ray[default] "
  319. "installation: `pip install 'ray[default]'`"
  320. )
  321. else:
  322. return JobDetails(**r.json())
  323. else:
  324. self._raise_error(r)
  325. @PublicAPI(stability="stable")
  326. def list_jobs(self) -> List[JobDetails]:
  327. """List all jobs along with their status and other information.
  328. Lists all jobs that have ever run on the cluster, including jobs that are
  329. currently running and jobs that are no longer running.
  330. Example:
  331. >>> from ray.job_submission import JobSubmissionClient
  332. >>> client = JobSubmissionClient("http://127.0.0.1:8265") # doctest: +SKIP
  333. >>> client.submit_job(entrypoint="echo hello") # doctest: +SKIP
  334. >>> client.submit_job(entrypoint="sleep 2") # doctest: +SKIP
  335. >>> client.list_jobs() # doctest: +SKIP
  336. [JobDetails(status='SUCCEEDED',
  337. job_id='03000000', type='submission',
  338. submission_id='raysubmit_4LamXRuQpYdSMg7J',
  339. message='Job finished successfully.', error_type=None,
  340. start_time=1647388711, end_time=1647388712, metadata={}, runtime_env={}),
  341. JobDetails(status='RUNNING',
  342. job_id='04000000', type='submission',
  343. submission_id='raysubmit_1dxCeNvG1fCMVNHG',
  344. message='Job is currently running.', error_type=None,
  345. start_time=1647454832, end_time=None, metadata={}, runtime_env={})]
  346. Returns:
  347. A list of JobDetails containing the job status and other information.
  348. Raises:
  349. RuntimeError: If the request to the job server fails.
  350. """
  351. r = self._do_request("GET", "/api/jobs/")
  352. if r.status_code == 200:
  353. jobs_info_json = r.json()
  354. jobs_info = [
  355. JobDetails(**job_info_json) for job_info_json in jobs_info_json
  356. ]
  357. return jobs_info
  358. else:
  359. self._raise_error(r)
  360. @PublicAPI(stability="stable")
  361. def get_job_status(self, job_id: str) -> JobStatus:
  362. """Get the most recent status of a job.
  363. Example:
  364. >>> from ray.job_submission import JobSubmissionClient
  365. >>> client = JobSubmissionClient("http://127.0.0.1:8265") # doctest: +SKIP
  366. >>> client.submit_job(entrypoint="echo hello") # doctest: +SKIP
  367. >>> client.get_job_status("raysubmit_4LamXRuQpYdSMg7J") # doctest: +SKIP
  368. 'SUCCEEDED'
  369. Args:
  370. job_id: The job ID or submission ID of the job whose status is being
  371. requested.
  372. Returns:
  373. The JobStatus of the job.
  374. Raises:
  375. RuntimeError: If the job does not exist or if the request to the
  376. job server fails.
  377. """
  378. return self.get_job_info(job_id).status
  379. @PublicAPI(stability="stable")
  380. def get_job_logs(self, job_id: str) -> str:
  381. """Get all logs produced by a job.
  382. Example:
  383. >>> from ray.job_submission import JobSubmissionClient
  384. >>> client = JobSubmissionClient("http://127.0.0.1:8265") # doctest: +SKIP
  385. >>> sub_id = client.submit_job(entrypoint="echo hello") # doctest: +SKIP
  386. >>> client.get_job_logs(sub_id) # doctest: +SKIP
  387. 'hello\\n'
  388. Args:
  389. job_id: The job ID or submission ID of the job whose logs are being
  390. requested.
  391. Returns:
  392. A string containing the full logs of the job.
  393. Raises:
  394. RuntimeError: If the job does not exist or if the request to the
  395. job server fails.
  396. """
  397. r = self._do_request("GET", f"/api/jobs/{job_id}/logs")
  398. if r.status_code == 200:
  399. return JobLogsResponse(**r.json()).logs
  400. else:
  401. self._raise_error(r)
  402. @PublicAPI(stability="stable")
  403. async def tail_job_logs(self, job_id: str) -> AsyncIterator[str]:
  404. """Get an iterator that follows the logs of a job.
  405. Example:
  406. >>> from ray.job_submission import JobSubmissionClient
  407. >>> client = JobSubmissionClient("http://127.0.0.1:8265") # doctest: +SKIP
  408. >>> submission_id = client.submit_job( # doctest: +SKIP
  409. ... entrypoint="echo hi && sleep 5 && echo hi2")
  410. >>> async for lines in client.tail_job_logs( # doctest: +SKIP
  411. ... 'raysubmit_Xe7cvjyGJCyuCvm2'):
  412. ... print(lines, end="") # doctest: +SKIP
  413. hi
  414. hi2
  415. Args:
  416. job_id: The job ID or submission ID of the job whose logs are being
  417. requested.
  418. Returns:
  419. The iterator.
  420. Raises:
  421. RuntimeError: If the job does not exist, if the request to the
  422. job server fails, or if the connection closes unexpectedly
  423. before the job reaches a terminal state.
  424. """
  425. async with aiohttp.ClientSession(
  426. cookies=self._cookies, headers=self._headers
  427. ) as session:
  428. ws = await session.ws_connect(
  429. f"{self._address}/api/jobs/{job_id}/logs/tail",
  430. headers=self._headers,
  431. ssl=self._ssl_context,
  432. )
  433. while True:
  434. msg = await ws.receive()
  435. if msg.type == aiohttp.WSMsgType.TEXT:
  436. yield msg.data
  437. elif msg.type == aiohttp.WSMsgType.CLOSED:
  438. logger.debug(
  439. f"WebSocket closed for job {job_id} with close code {ws.close_code}"
  440. )
  441. if ws.close_code == aiohttp.WSCloseCode.ABNORMAL_CLOSURE:
  442. raise RuntimeError(
  443. f"WebSocket connection closed unexpectedly with close code {ws.close_code}"
  444. )
  445. break
  446. elif msg.type == aiohttp.WSMsgType.ERROR:
  447. # Old Ray versions (<=2.0.1) may send ERROR on connection close
  448. if self._server_ray_version is not None and packaging.version.parse(
  449. self._server_ray_version
  450. ) > packaging.version.parse("2.0.1"):
  451. raise RuntimeError(
  452. f"WebSocket error for job {job_id}: {ws.exception()}"
  453. )
  454. else:
  455. logger.debug(
  456. f"WebSocket error for job {job_id}, treating as normal close. Err: {ws.exception()}"
  457. )
  458. break