dashboard_sdk.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465
  1. import dataclasses
  2. import importlib
  3. import json
  4. import logging
  5. import os
  6. import ssl
  7. import tempfile
  8. from pathlib import Path
  9. from typing import Any, Dict, List, Optional, Union
  10. import packaging.version
  11. import yaml
  12. import ray
  13. from ray._private.authentication.http_token_authentication import (
  14. format_authentication_http_error,
  15. get_auth_headers_if_auth_enabled,
  16. )
  17. from ray._private.runtime_env.packaging import (
  18. create_package,
  19. get_uri_for_directory,
  20. get_uri_for_package,
  21. )
  22. from ray._private.runtime_env.py_modules import upload_py_modules_if_needed
  23. from ray._private.runtime_env.working_dir import upload_working_dir_if_needed
  24. from ray._private.utils import split_address
  25. from ray.autoscaler._private.cli_logger import cli_logger
  26. from ray.dashboard.modules.job.common import uri_to_http_components
  27. from ray.exceptions import AuthenticationError
  28. from ray.util.annotations import DeveloperAPI, PublicAPI
# `requests` is treated as an optional dependency: when it cannot be imported
# the name is bound to None so that importing this module still succeeds.
try:
    import requests
except ImportError:
    requests = None

# Module-level logger; its level is pinned to INFO here.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# By default, connect to local cluster.
DEFAULT_DASHBOARD_ADDRESS = "http://localhost:8265"
  37. def parse_runtime_env_args(
  38. runtime_env: Optional[str] = None,
  39. runtime_env_json: Optional[str] = None,
  40. working_dir: Optional[str] = None,
  41. ):
  42. """
  43. Generates a runtime_env dictionary using `runtime_env`, `runtime_env_json`,
  44. and `working_dir` CLI options. Only one of `runtime_env` or
  45. `runtime_env_json` may be defined. `working_dir` overwrites the
  46. `working_dir` from any other option.
  47. """
  48. final_runtime_env = {}
  49. if runtime_env is not None:
  50. if runtime_env_json is not None:
  51. raise ValueError(
  52. "Only one of --runtime_env and --runtime-env-json can be provided."
  53. )
  54. with open(runtime_env, "r") as f:
  55. final_runtime_env = yaml.safe_load(f)
  56. elif runtime_env_json is not None:
  57. final_runtime_env = json.loads(runtime_env_json)
  58. if working_dir is not None:
  59. if "working_dir" in final_runtime_env:
  60. cli_logger.warning(
  61. "Overriding runtime_env working_dir with --working-dir option"
  62. )
  63. final_runtime_env["working_dir"] = working_dir
  64. return final_runtime_env
  65. @dataclasses.dataclass
  66. class ClusterInfo:
  67. address: str
  68. cookies: Optional[Dict[str, Any]] = None
  69. metadata: Optional[Dict[str, Any]] = None
  70. headers: Optional[Dict[str, Any]] = None
  71. # TODO (shrekris-anyscale): renaming breaks compatibility, do NOT rename
  72. def get_job_submission_client_cluster_info(
  73. address: str,
  74. # For backwards compatibility
  75. *,
  76. # only used in importlib case in parse_cluster_info, but needed
  77. # in function signature.
  78. create_cluster_if_needed: Optional[bool] = False,
  79. cookies: Optional[Dict[str, Any]] = None,
  80. metadata: Optional[Dict[str, Any]] = None,
  81. headers: Optional[Dict[str, Any]] = None,
  82. _use_tls: Optional[bool] = False,
  83. ) -> ClusterInfo:
  84. """Get address, cookies, and metadata used for SubmissionClient.
  85. If no port is specified in `address`, the Ray dashboard default will be
  86. inserted.
  87. Args:
  88. address: Address without the module prefix that is passed
  89. to SubmissionClient.
  90. create_cluster_if_needed: Indicates whether the cluster
  91. of the address returned needs to be running. Ray doesn't
  92. start a cluster before interacting with jobs, but other
  93. implementations may do so.
  94. Returns:
  95. ClusterInfo object consisting of address, cookies, and metadata
  96. for SubmissionClient to use.
  97. """
  98. scheme = "https" if _use_tls else "http"
  99. return ClusterInfo(
  100. address=f"{scheme}://{address}",
  101. cookies=cookies,
  102. metadata=metadata,
  103. headers=headers,
  104. )
  105. def parse_cluster_info(
  106. address: Optional[str] = None,
  107. create_cluster_if_needed: bool = False,
  108. cookies: Optional[Dict[str, Any]] = None,
  109. metadata: Optional[Dict[str, Any]] = None,
  110. headers: Optional[Dict[str, Any]] = None,
  111. ) -> ClusterInfo:
  112. """Create a cluster if needed and return its address, cookies, and metadata."""
  113. if address is None:
  114. if (
  115. ray.is_initialized()
  116. and ray._private.worker.global_worker.node.address_info["webui_url"]
  117. is not None
  118. ):
  119. address = (
  120. "http://"
  121. f"{ray._private.worker.global_worker.node.address_info['webui_url']}"
  122. )
  123. logger.info(
  124. f"No address provided but Ray is running; using address {address}."
  125. )
  126. else:
  127. logger.info(
  128. f"No address provided, defaulting to {DEFAULT_DASHBOARD_ADDRESS}."
  129. )
  130. address = DEFAULT_DASHBOARD_ADDRESS
  131. if address == "auto":
  132. raise ValueError("Internal error: unexpected address 'auto'.")
  133. if "://" not in address:
  134. # Default to HTTP.
  135. logger.info(
  136. "No scheme (e.g. 'http://') or module string (e.g. 'ray://') "
  137. f"provided in address {address}, defaulting to HTTP."
  138. )
  139. address = f"http://{address}"
  140. module_string, inner_address = split_address(address)
  141. if module_string == "ray":
  142. raise ValueError(f"Internal error: unexpected Ray Client address {address}.")
  143. # If user passes http(s)://, go through normal parsing.
  144. if module_string in {"http", "https"}:
  145. return get_job_submission_client_cluster_info(
  146. inner_address,
  147. create_cluster_if_needed=create_cluster_if_needed,
  148. cookies=cookies,
  149. metadata=metadata,
  150. headers=headers,
  151. _use_tls=(module_string == "https"),
  152. )
  153. # Try to dynamically import the function to get cluster info.
  154. else:
  155. try:
  156. module = importlib.import_module(module_string)
  157. except Exception:
  158. raise RuntimeError(
  159. f"Module: {module_string} does not exist.\n"
  160. f"This module was parsed from address: {address}"
  161. ) from None
  162. assert "get_job_submission_client_cluster_info" in dir(module), (
  163. f"Module: {module_string} does "
  164. "not have `get_job_submission_client_cluster_info`.\n"
  165. f"This module was parsed from address: {address}"
  166. )
  167. return module.get_job_submission_client_cluster_info(
  168. inner_address,
  169. create_cluster_if_needed=create_cluster_if_needed,
  170. cookies=cookies,
  171. metadata=metadata,
  172. headers=headers,
  173. )
  174. class SubmissionClient:
  175. def __init__(
  176. self,
  177. address: Optional[str] = None,
  178. create_cluster_if_needed: bool = False,
  179. cookies: Optional[Dict[str, Any]] = None,
  180. metadata: Optional[Dict[str, Any]] = None,
  181. headers: Optional[Dict[str, Any]] = None,
  182. verify: Optional[Union[str, bool]] = True,
  183. ):
  184. # Remove any trailing slashes
  185. if address is not None and address.endswith("/"):
  186. address = address.rstrip("/")
  187. logger.debug(
  188. "The submission address cannot contain trailing slashes. Removing "
  189. f'them from the requested submission address of "{address}".'
  190. )
  191. cluster_info = parse_cluster_info(
  192. address, create_cluster_if_needed, cookies, metadata, headers
  193. )
  194. self._address = cluster_info.address
  195. self._cookies = cluster_info.cookies
  196. self._default_metadata = cluster_info.metadata or {}
  197. # Headers used for all requests sent to job server, optional and only
  198. # needed for cases like authentication to remote cluster.
  199. self._headers = cluster_info.headers or {}
  200. self._headers.update(**get_auth_headers_if_auth_enabled(self._headers))
  201. # Set SSL verify parameter for the requests library and create an ssl_context
  202. # object when needed for the aiohttp library.
  203. self._verify = verify
  204. if isinstance(self._verify, str):
  205. if os.path.isdir(self._verify):
  206. cafile, capath = None, self._verify
  207. elif os.path.isfile(self._verify):
  208. cafile, capath = self._verify, None
  209. else:
  210. raise FileNotFoundError(
  211. f"Path to CA certificates: '{self._verify}', does not exist."
  212. )
  213. self._ssl_context = ssl.create_default_context(cafile=cafile, capath=capath)
  214. else:
  215. if self._verify is False:
  216. self._ssl_context = False
  217. else:
  218. self._ssl_context = None
  219. self._server_ray_version: Optional[str] = None
  220. def _check_connection_and_version(
  221. self, min_version: str = "1.9", version_error_message: str = None
  222. ):
  223. self._check_connection_and_version_with_url(min_version, version_error_message)
  224. def _check_connection_and_version_with_url(
  225. self,
  226. min_version: str = "1.9",
  227. version_error_message: str = None,
  228. url: str = "/api/version",
  229. ):
  230. if version_error_message is None:
  231. version_error_message = (
  232. f"Please ensure the cluster is running Ray {min_version} or higher."
  233. )
  234. try:
  235. r = self._do_request("GET", url)
  236. if r.status_code == 404:
  237. raise RuntimeError(
  238. "Version check returned 404. " + version_error_message
  239. )
  240. r.raise_for_status()
  241. running_ray_version = r.json()["ray_version"]
  242. self._server_ray_version = running_ray_version
  243. if packaging.version.parse(running_ray_version) < packaging.version.parse(
  244. min_version
  245. ):
  246. raise RuntimeError(
  247. f"Ray version {running_ray_version} is running on the cluster. "
  248. + version_error_message
  249. )
  250. except requests.exceptions.ConnectionError:
  251. raise ConnectionError(
  252. f"Failed to connect to Ray at address: {self._address}."
  253. )
  254. def _raise_error(self, r: "requests.Response"):
  255. raise RuntimeError(
  256. f"Request failed with status code {r.status_code}: {r.text}."
  257. )
  258. def _do_request(
  259. self,
  260. method: str,
  261. endpoint: str,
  262. *,
  263. data: Optional[bytes] = None,
  264. json_data: Optional[dict] = None,
  265. **kwargs,
  266. ) -> "requests.Response":
  267. """Perform the actual HTTP request with authentication error handling.
  268. Keyword arguments other than "cookies", "headers" are forwarded to the
  269. `requests.request()`.
  270. """
  271. url = self._address + endpoint
  272. logger.debug(f"Sending request to {url} with json data: {json_data or {}}.")
  273. response = requests.request(
  274. method,
  275. url,
  276. cookies=self._cookies,
  277. data=data,
  278. json=json_data,
  279. headers=self._headers,
  280. verify=self._verify,
  281. **kwargs,
  282. )
  283. # Check for authentication errors and provide helpful messages
  284. formatted_error = format_authentication_http_error(
  285. response.status_code, response.text
  286. )
  287. if formatted_error:
  288. raise AuthenticationError(formatted_error)
  289. return response
  290. def _package_exists(
  291. self,
  292. package_uri: str,
  293. ) -> bool:
  294. protocol, package_name = uri_to_http_components(package_uri)
  295. r = self._do_request("GET", f"/api/packages/{protocol}/{package_name}")
  296. if r.status_code == 200:
  297. logger.debug(f"Package {package_uri} already exists.")
  298. return True
  299. elif r.status_code == 404:
  300. logger.debug(f"Package {package_uri} does not exist.")
  301. return False
  302. else:
  303. self._raise_error(r)
  304. def _upload_package(
  305. self,
  306. package_uri: str,
  307. package_path: str,
  308. include_gitignore: bool,
  309. include_parent_dir: Optional[bool] = False,
  310. excludes: Optional[List[str]] = None,
  311. is_file: bool = False,
  312. ) -> bool:
  313. logger.info(f"Uploading package {package_uri}.")
  314. with tempfile.TemporaryDirectory() as tmp_dir:
  315. protocol, package_name = uri_to_http_components(package_uri)
  316. if is_file:
  317. package_file = Path(package_path)
  318. else:
  319. package_file = Path(tmp_dir) / package_name
  320. create_package(
  321. package_path,
  322. package_file,
  323. include_gitignore=include_gitignore,
  324. include_parent_dir=include_parent_dir,
  325. excludes=excludes,
  326. )
  327. try:
  328. r = self._do_request(
  329. "PUT",
  330. f"/api/packages/{protocol}/{package_name}",
  331. data=package_file.read_bytes(),
  332. )
  333. if r.status_code != 200:
  334. self._raise_error(r)
  335. finally:
  336. # If the package is a user's existing file, don't delete it.
  337. if not is_file:
  338. package_file.unlink()
  339. def _upload_package_if_needed(
  340. self,
  341. package_path: str,
  342. include_gitignore: bool,
  343. include_parent_dir: bool = False,
  344. excludes: Optional[List[str]] = None,
  345. is_file: bool = False,
  346. ) -> str:
  347. if is_file:
  348. package_uri = get_uri_for_package(Path(package_path))
  349. else:
  350. package_uri = get_uri_for_directory(
  351. package_path, include_gitignore, excludes=excludes
  352. )
  353. if not self._package_exists(package_uri):
  354. self._upload_package(
  355. package_uri,
  356. package_path,
  357. include_gitignore=include_gitignore,
  358. include_parent_dir=include_parent_dir,
  359. excludes=excludes,
  360. is_file=is_file,
  361. )
  362. else:
  363. logger.info(f"Package {package_uri} already exists, skipping upload.")
  364. return package_uri
  365. def _upload_working_dir_if_needed(self, runtime_env: Dict[str, Any]):
  366. from ray._private.ray_constants import RAY_RUNTIME_ENV_IGNORE_GITIGNORE
  367. # Determine whether to respect .gitignore files based on environment variable
  368. # Default is True (respect .gitignore). Set to False if env var is "1".
  369. include_gitignore = os.environ.get(RAY_RUNTIME_ENV_IGNORE_GITIGNORE, "0") != "1"
  370. def _upload_fn(working_dir, excludes, is_file=False):
  371. self._upload_package_if_needed(
  372. working_dir,
  373. include_gitignore=include_gitignore,
  374. include_parent_dir=False,
  375. excludes=excludes,
  376. is_file=is_file,
  377. )
  378. upload_working_dir_if_needed(
  379. runtime_env, include_gitignore=include_gitignore, upload_fn=_upload_fn
  380. )
  381. def _upload_py_modules_if_needed(self, runtime_env: Dict[str, Any]):
  382. from ray._private.ray_constants import RAY_RUNTIME_ENV_IGNORE_GITIGNORE
  383. # Determine whether to respect .gitignore files based on environment variable
  384. # Default is True (respect .gitignore). Set to False if env var is "1".
  385. include_gitignore = os.environ.get(RAY_RUNTIME_ENV_IGNORE_GITIGNORE, "0") != "1"
  386. def _upload_fn(module_path, excludes, is_file=False):
  387. self._upload_package_if_needed(
  388. module_path,
  389. include_gitignore=include_gitignore,
  390. include_parent_dir=True,
  391. excludes=excludes,
  392. is_file=is_file,
  393. )
  394. upload_py_modules_if_needed(
  395. runtime_env, include_gitignore=include_gitignore, upload_fn=_upload_fn
  396. )
  397. @PublicAPI(stability="beta")
  398. def get_version(self) -> str:
  399. r = self._do_request("GET", "/api/version")
  400. if r.status_code == 200:
  401. return r.json().get("version")
  402. else:
  403. self._raise_error(r)
  404. @DeveloperAPI
  405. def get_address(self) -> str:
  406. return self._address