conda.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396
  1. import hashlib
  2. import json
  3. import logging
  4. import os
  5. import runpy
  6. import shutil
  7. import subprocess
  8. import sys
  9. from pathlib import Path
  10. from typing import Any, Dict, List, Optional
  11. import yaml
  12. from filelock import FileLock
  13. import ray
  14. from ray._common.utils import (
  15. get_or_create_event_loop,
  16. try_to_create_directory,
  17. )
  18. from ray._private.runtime_env.conda_utils import (
  19. create_conda_env_if_needed,
  20. delete_conda_env,
  21. get_conda_activate_commands,
  22. get_conda_envs,
  23. get_conda_info_json,
  24. )
  25. from ray._private.runtime_env.context import RuntimeEnvContext
  26. from ray._private.runtime_env.packaging import Protocol, parse_uri
  27. from ray._private.runtime_env.plugin import RuntimeEnvPlugin
  28. from ray._private.runtime_env.validation import parse_and_validate_conda
  29. from ray._private.utils import (
  30. get_directory_size_bytes,
  31. get_master_wheel_url,
  32. get_release_wheel_url,
  33. get_wheel_filename,
  34. )
  35. default_logger = logging.getLogger(__name__)
  36. _WIN32 = os.name == "nt"
  37. def _resolve_current_ray_path() -> str:
  38. # When ray is built from source with pip install -e,
  39. # ray.__file__ returns .../python/ray/__init__.py and this function returns
  40. # ".../python".
  41. # When ray is installed from a prebuilt binary, ray.__file__ returns
  42. # .../site-packages/ray/__init__.py and this function returns
  43. # ".../site-packages".
  44. return os.path.split(os.path.split(ray.__file__)[0])[0]
  45. def _get_ray_setup_spec():
  46. """Find the Ray setup_spec from the currently running Ray.
  47. This function works even when Ray is built from source with pip install -e.
  48. """
  49. ray_source_python_path = _resolve_current_ray_path()
  50. setup_py_path = os.path.join(ray_source_python_path, "setup.py")
  51. return runpy.run_path(setup_py_path)["setup_spec"]
  52. def _resolve_install_from_source_ray_dependencies():
  53. """Find the Ray dependencies when Ray is installed from source."""
  54. deps = (
  55. _get_ray_setup_spec().install_requires + _get_ray_setup_spec().extras["default"]
  56. )
  57. # Remove duplicates
  58. return list(set(deps))
  59. def _inject_ray_to_conda_site(
  60. conda_path, logger: Optional[logging.Logger] = default_logger
  61. ):
  62. """Write the current Ray site package directory to a new site"""
  63. if _WIN32:
  64. python_binary = os.path.join(conda_path, "python")
  65. else:
  66. python_binary = os.path.join(conda_path, "bin/python")
  67. site_packages_path = (
  68. subprocess.check_output(
  69. [
  70. python_binary,
  71. "-c",
  72. "import sysconfig; print(sysconfig.get_paths()['purelib'])",
  73. ]
  74. )
  75. .decode()
  76. .strip()
  77. )
  78. ray_path = _resolve_current_ray_path()
  79. logger.warning(
  80. f"Injecting {ray_path} to environment site-packages {site_packages_path} "
  81. "because _inject_current_ray flag is on."
  82. )
  83. maybe_ray_dir = os.path.join(site_packages_path, "ray")
  84. if os.path.isdir(maybe_ray_dir):
  85. logger.warning(f"Replacing existing ray installation with {ray_path}")
  86. shutil.rmtree(maybe_ray_dir)
  87. # See usage of *.pth file at
  88. # https://docs.python.org/3/library/site.html
  89. with open(os.path.join(site_packages_path, "ray_shared.pth"), "w") as f:
  90. f.write(ray_path)
  91. def _current_py_version():
  92. return ".".join(map(str, sys.version_info[:3])) # like 3.6.10
  93. def current_ray_pip_specifier(
  94. logger: Optional[logging.Logger] = default_logger,
  95. ) -> Optional[str]:
  96. """The pip requirement specifier for the running version of Ray.
  97. Returns:
  98. A string which can be passed to `pip install` to install the
  99. currently running Ray version, or None if running on a version
  100. built from source locally (likely if you are developing Ray).
  101. Examples:
  102. Returns "https://s3-us-west-2.amazonaws.com/ray-wheels/[..].whl"
  103. if running a stable release, a nightly or a specific commit
  104. """
  105. if os.environ.get("RAY_CI_POST_WHEEL_TESTS"):
  106. # Running in Buildkite CI after the wheel has been built.
  107. # Wheels are at in the ray/.whl directory, but use relative path to
  108. # allow for testing locally if needed.
  109. return os.path.join(
  110. Path(ray.__file__).resolve().parents[2], ".whl", get_wheel_filename()
  111. )
  112. elif ray.__commit__ == "{{RAY_COMMIT_SHA}}":
  113. # Running on a version built from source locally.
  114. if os.environ.get("RAY_RUNTIME_ENV_LOCAL_DEV_MODE") != "1":
  115. logger.warning(
  116. "Current Ray version could not be detected, most likely "
  117. "because you have manually built Ray from source. To use "
  118. "runtime_env in this case, set the environment variable "
  119. "RAY_RUNTIME_ENV_LOCAL_DEV_MODE=1."
  120. )
  121. return None
  122. elif "dev" in ray.__version__:
  123. # Running on a nightly wheel.
  124. return get_master_wheel_url()
  125. else:
  126. return get_release_wheel_url()
  127. def inject_dependencies(
  128. conda_dict: Dict[Any, Any],
  129. py_version: str,
  130. pip_dependencies: Optional[List[str]] = None,
  131. ) -> Dict[Any, Any]:
  132. """Add Ray, Python and (optionally) extra pip dependencies to a conda dict.
  133. Args:
  134. conda_dict: A dict representing the JSON-serialized conda
  135. environment YAML file. This dict will be modified and returned.
  136. py_version: A string representing a Python version to inject
  137. into the conda dependencies, e.g. "3.7.7"
  138. pip_dependencies (List[str]): A list of pip dependencies that
  139. will be prepended to the list of pip dependencies in
  140. the conda dict. If the conda dict does not already have a "pip"
  141. field, one will be created.
  142. Returns:
  143. The modified dict. (Note: the input argument conda_dict is modified
  144. and returned.)
  145. """
  146. if pip_dependencies is None:
  147. pip_dependencies = []
  148. if conda_dict.get("dependencies") is None:
  149. conda_dict["dependencies"] = []
  150. # Inject Python dependency.
  151. deps = conda_dict["dependencies"]
  152. # Add current python dependency. If the user has already included a
  153. # python version dependency, conda will raise a readable error if the two
  154. # are incompatible, e.g:
  155. # ResolvePackageNotFound: - python[version='3.5.*,>=3.6']
  156. deps.append(f"python={py_version}")
  157. if "pip" not in deps:
  158. deps.append("pip")
  159. # Insert pip dependencies.
  160. found_pip_dict = False
  161. for dep in deps:
  162. if isinstance(dep, dict) and dep.get("pip") and isinstance(dep["pip"], list):
  163. dep["pip"] = pip_dependencies + dep["pip"]
  164. found_pip_dict = True
  165. break
  166. if not found_pip_dict:
  167. deps.append({"pip": pip_dependencies})
  168. return conda_dict
  169. def _get_conda_env_hash(conda_dict: Dict) -> str:
  170. # Set `sort_keys=True` so that different orderings yield the same hash.
  171. serialized_conda_spec = json.dumps(conda_dict, sort_keys=True)
  172. hash = hashlib.sha1(serialized_conda_spec.encode("utf-8")).hexdigest()
  173. return hash
  174. def get_uri(runtime_env: Dict) -> Optional[str]:
  175. """Return `"conda://<hashed_dependencies>"`, or None if no GC required."""
  176. conda = runtime_env.get("conda")
  177. if conda is not None:
  178. if isinstance(conda, str):
  179. # User-preinstalled conda env. We don't garbage collect these, so
  180. # we don't track them with URIs.
  181. uri = None
  182. elif isinstance(conda, dict):
  183. uri = f"conda://{_get_conda_env_hash(conda_dict=conda)}"
  184. else:
  185. raise TypeError(
  186. "conda field received by RuntimeEnvAgent must be "
  187. f"str or dict, not {type(conda).__name__}."
  188. )
  189. else:
  190. uri = None
  191. return uri
  192. def _get_conda_dict_with_ray_inserted(
  193. runtime_env: "RuntimeEnv", # noqa: F821
  194. logger: Optional[logging.Logger] = default_logger,
  195. ) -> Dict[str, Any]:
  196. """Returns the conda spec with the Ray and `python` dependency inserted."""
  197. conda_dict = json.loads(runtime_env.conda_config())
  198. assert conda_dict is not None
  199. ray_pip = current_ray_pip_specifier(logger=logger)
  200. if ray_pip:
  201. extra_pip_dependencies = [ray_pip, "ray[default]"]
  202. elif runtime_env.get_extension("_inject_current_ray"):
  203. extra_pip_dependencies = _resolve_install_from_source_ray_dependencies()
  204. else:
  205. extra_pip_dependencies = []
  206. conda_dict = inject_dependencies(
  207. conda_dict, _current_py_version(), extra_pip_dependencies
  208. )
  209. return conda_dict
  210. class CondaPlugin(RuntimeEnvPlugin):
  211. name = "conda"
  212. def __init__(self, resources_dir: str):
  213. self._resources_dir = os.path.join(resources_dir, "conda")
  214. try_to_create_directory(self._resources_dir)
  215. # It is not safe for multiple processes to install conda envs
  216. # concurrently, even if the envs are different, so use a global
  217. # lock for all conda installs and deletions.
  218. # See https://github.com/ray-project/ray/issues/17086
  219. self._installs_and_deletions_file_lock = os.path.join(
  220. self._resources_dir, "ray-conda-installs-and-deletions.lock"
  221. )
  222. # A set of named conda environments (instead of yaml or dict)
  223. # that are validated to exist.
  224. # NOTE: It has to be only used within the same thread, which
  225. # is an event loop.
  226. # Also, we don't need to GC this field because it is pretty small.
  227. self._validated_named_conda_env = set()
  228. def _get_path_from_hash(self, hash: str) -> str:
  229. """Generate a path from the hash of a conda or pip spec.
  230. The output path also functions as the name of the conda environment
  231. when using the `--prefix` option to `conda create` and `conda remove`.
  232. Example output:
  233. /tmp/ray/session_2021-11-03_16-33-59_356303_41018/runtime_resources
  234. /conda/ray-9a7972c3a75f55e976e620484f58410c920db091
  235. """
  236. return os.path.join(self._resources_dir, hash)
  237. def get_uris(self, runtime_env: "RuntimeEnv") -> List[str]: # noqa: F821
  238. """Return the conda URI from the RuntimeEnv if it exists, else return []."""
  239. conda_uri = runtime_env.conda_uri()
  240. if conda_uri:
  241. return [conda_uri]
  242. return []
  243. def delete_uri(
  244. self, uri: str, logger: Optional[logging.Logger] = default_logger
  245. ) -> int:
  246. """Delete URI and return the number of bytes deleted."""
  247. logger.info(f"Got request to delete URI {uri}")
  248. protocol, hash = parse_uri(uri)
  249. if protocol != Protocol.CONDA:
  250. raise ValueError(
  251. "CondaPlugin can only delete URIs with protocol "
  252. f"conda. Received protocol {protocol}, URI {uri}"
  253. )
  254. conda_env_path = self._get_path_from_hash(hash)
  255. local_dir_size = get_directory_size_bytes(conda_env_path)
  256. with FileLock(self._installs_and_deletions_file_lock):
  257. successful = delete_conda_env(prefix=conda_env_path, logger=logger)
  258. if not successful:
  259. logger.warning(f"Error when deleting conda env {conda_env_path}. ")
  260. return 0
  261. return local_dir_size
  262. async def create(
  263. self,
  264. uri: Optional[str],
  265. runtime_env: "RuntimeEnv", # noqa: F821
  266. context: RuntimeEnvContext,
  267. logger: logging.Logger = default_logger,
  268. ) -> int:
  269. if not runtime_env.has_conda():
  270. return 0
  271. def _create():
  272. result = parse_and_validate_conda(runtime_env.get("conda"))
  273. if isinstance(result, str):
  274. # The conda env name is given.
  275. # In this case, we only verify if the given
  276. # conda env exists.
  277. # If the env is already validated, do nothing.
  278. if result in self._validated_named_conda_env:
  279. return 0
  280. conda_info = get_conda_info_json()
  281. envs = get_conda_envs(conda_info)
  282. # We accept `result` as a conda name or full path.
  283. if not any(result == env[0] or result == env[1] for env in envs):
  284. raise ValueError(
  285. f"The given conda environment '{result}' "
  286. f"from the runtime env {runtime_env} doesn't "
  287. "exist from the output of `conda info --json`. "
  288. "You can only specify an env that already exists. "
  289. f"Please make sure to create an env {result} "
  290. )
  291. self._validated_named_conda_env.add(result)
  292. return 0
  293. logger.debug(
  294. "Setting up conda for runtime_env: " f"{runtime_env.serialize()}"
  295. )
  296. protocol, hash = parse_uri(uri)
  297. conda_env_name = self._get_path_from_hash(hash)
  298. conda_dict = _get_conda_dict_with_ray_inserted(runtime_env, logger=logger)
  299. logger.info(f"Setting up conda environment with {runtime_env}")
  300. with FileLock(self._installs_and_deletions_file_lock):
  301. try:
  302. conda_yaml_file = os.path.join(
  303. self._resources_dir, "environment.yml"
  304. )
  305. with open(conda_yaml_file, "w") as file:
  306. yaml.dump(conda_dict, file)
  307. create_conda_env_if_needed(
  308. conda_yaml_file, prefix=conda_env_name, logger=logger
  309. )
  310. finally:
  311. os.remove(conda_yaml_file)
  312. if runtime_env.get_extension("_inject_current_ray"):
  313. _inject_ray_to_conda_site(conda_path=conda_env_name, logger=logger)
  314. logger.info(f"Finished creating conda environment at {conda_env_name}")
  315. return get_directory_size_bytes(conda_env_name)
  316. loop = get_or_create_event_loop()
  317. return await loop.run_in_executor(None, _create)
  318. def modify_context(
  319. self,
  320. uris: List[str],
  321. runtime_env: "RuntimeEnv", # noqa: F821
  322. context: RuntimeEnvContext,
  323. logger: Optional[logging.Logger] = default_logger,
  324. ):
  325. if not runtime_env.has_conda():
  326. return
  327. if runtime_env.conda_env_name():
  328. conda_env_name = runtime_env.conda_env_name()
  329. else:
  330. protocol, hash = parse_uri(runtime_env.conda_uri())
  331. conda_env_name = self._get_path_from_hash(hash)
  332. context.py_executable = "python"
  333. context.command_prefix += get_conda_activate_commands(conda_env_name)