pip.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344
  1. import asyncio
  2. import hashlib
  3. import json
  4. import logging
  5. import os
  6. import shutil
  7. import sys
  8. from asyncio import create_task, get_running_loop
  9. from typing import Dict, List, Optional
  10. from ray._common.utils import try_to_create_directory
  11. from ray._private.runtime_env import dependency_utils, virtualenv_utils
  12. from ray._private.runtime_env.packaging import Protocol, parse_uri
  13. from ray._private.runtime_env.plugin import RuntimeEnvPlugin
  14. from ray._private.runtime_env.utils import check_output_cmd
  15. from ray._private.utils import get_directory_size_bytes
# Module-level logger; used as the default value of the `logger` parameter
# accepted by PipProcessor and the PipPlugin methods in this file.
default_logger = logging.getLogger(__name__)
  17. def _get_pip_hash(pip_dict: Dict) -> str:
  18. serialized_pip_spec = json.dumps(pip_dict, sort_keys=True)
  19. hash_val = hashlib.sha1(serialized_pip_spec.encode("utf-8")).hexdigest()
  20. return hash_val
  21. def get_uri(runtime_env: Dict) -> Optional[str]:
  22. """Return `"pip://<hashed_dependencies>"`, or None if no GC required."""
  23. pip = runtime_env.get("pip")
  24. if pip is not None:
  25. if isinstance(pip, dict):
  26. uri = "pip://" + _get_pip_hash(pip_dict=pip)
  27. elif isinstance(pip, list):
  28. uri = "pip://" + _get_pip_hash(pip_dict=dict(packages=pip))
  29. else:
  30. raise TypeError(
  31. "pip field received by RuntimeEnvAgent must be "
  32. f"list or dict, not {type(pip).__name__}."
  33. )
  34. else:
  35. uri = None
  36. return uri
  37. class PipProcessor:
  38. def __init__(
  39. self,
  40. target_dir: str,
  41. runtime_env: "RuntimeEnv", # noqa: F821
  42. logger: Optional[logging.Logger] = default_logger,
  43. ):
  44. try:
  45. import virtualenv # noqa: F401 ensure virtualenv exists.
  46. except ImportError:
  47. raise RuntimeError(
  48. f"Please install virtualenv "
  49. f"`{sys.executable} -m pip install virtualenv`"
  50. f"to enable pip runtime env."
  51. )
  52. logger.debug("Setting up pip for runtime_env: %s", runtime_env)
  53. self._target_dir = target_dir
  54. self._runtime_env = runtime_env
  55. self._logger = logger
  56. self._pip_config = self._runtime_env.pip_config()
  57. self._pip_env = os.environ.copy()
  58. self._pip_env.update(self._runtime_env.env_vars())
  59. @classmethod
  60. async def _ensure_pip_version(
  61. cls,
  62. path: str,
  63. pip_version: Optional[str],
  64. cwd: str,
  65. pip_env: Dict,
  66. logger: logging.Logger,
  67. ):
  68. """Run the pip command to reinstall pip to the specified version."""
  69. if not pip_version:
  70. return
  71. python = virtualenv_utils.get_virtualenv_python(path)
  72. # Ensure pip version.
  73. pip_reinstall_cmd = [
  74. python,
  75. "-m",
  76. "pip",
  77. "install",
  78. "--disable-pip-version-check",
  79. f"pip{pip_version}",
  80. ]
  81. logger.info("Installing pip with version %s", pip_version)
  82. await check_output_cmd(pip_reinstall_cmd, logger=logger, cwd=cwd, env=pip_env)
  83. async def _pip_check(
  84. self,
  85. path: str,
  86. pip_check: bool,
  87. cwd: str,
  88. pip_env: Dict,
  89. logger: logging.Logger,
  90. ):
  91. """Run the pip check command to check python dependency conflicts.
  92. If exists conflicts, the exit code of pip check command will be non-zero.
  93. """
  94. if not pip_check:
  95. logger.info("Skip pip check.")
  96. return
  97. python = virtualenv_utils.get_virtualenv_python(path)
  98. await check_output_cmd(
  99. [python, "-m", "pip", "check", "--disable-pip-version-check"],
  100. logger=logger,
  101. cwd=cwd,
  102. env=pip_env,
  103. )
  104. logger.info("Pip check on %s successfully.", path)
  105. async def _install_pip_packages(
  106. self,
  107. path: str,
  108. pip_packages: List[str],
  109. cwd: str,
  110. pip_env: Dict,
  111. logger: logging.Logger,
  112. ):
  113. virtualenv_path = virtualenv_utils.get_virtualenv_path(path)
  114. python = virtualenv_utils.get_virtualenv_python(path)
  115. # TODO(fyrestone): Support -i, --no-deps, --no-cache-dir, ...
  116. pip_requirements_file = dependency_utils.get_requirements_file(
  117. path, pip_packages
  118. )
  119. # Avoid blocking the event loop.
  120. loop = get_running_loop()
  121. await loop.run_in_executor(
  122. None,
  123. dependency_utils.gen_requirements_txt,
  124. pip_requirements_file,
  125. pip_packages,
  126. )
  127. # Install all dependencies
  128. # The default options for pip install are
  129. #
  130. # --disable-pip-version-check
  131. # Don't periodically check PyPI to determine whether a new version
  132. # of pip is available for download.
  133. #
  134. # --no-cache-dir
  135. # Disable the cache, the pip runtime env is a one-time installation,
  136. # and we don't need to handle the pip cache broken.
  137. #
  138. # Allow users to specify their own options to install packages via `pip`.
  139. pip_install_cmd = [
  140. python,
  141. "-m",
  142. "pip",
  143. "install",
  144. "-r",
  145. pip_requirements_file,
  146. ]
  147. pip_opt_list = self._pip_config.get(
  148. "pip_install_options", ["--disable-pip-version-check", "--no-cache-dir"]
  149. )
  150. pip_install_cmd.extend(pip_opt_list)
  151. logger.info("Installing python requirements to %s", virtualenv_path)
  152. await check_output_cmd(pip_install_cmd, logger=logger, cwd=cwd, env=pip_env)
  153. async def _run(self):
  154. path = self._target_dir
  155. logger = self._logger
  156. pip_packages = self._pip_config["packages"]
  157. # We create an empty directory for exec cmd so that the cmd will
  158. # run more stable. e.g. if cwd has ray, then checking ray will
  159. # look up ray in cwd instead of site packages.
  160. exec_cwd = os.path.join(path, "exec_cwd")
  161. os.makedirs(exec_cwd, exist_ok=True)
  162. try:
  163. await virtualenv_utils.create_or_get_virtualenv(path, exec_cwd, logger)
  164. python = virtualenv_utils.get_virtualenv_python(path)
  165. async with dependency_utils.check_ray(python, exec_cwd, logger):
  166. # Ensure pip version.
  167. await self._ensure_pip_version(
  168. path,
  169. self._pip_config.get("pip_version", None),
  170. exec_cwd,
  171. self._pip_env,
  172. logger,
  173. )
  174. # Install pip packages.
  175. await self._install_pip_packages(
  176. path,
  177. pip_packages,
  178. exec_cwd,
  179. self._pip_env,
  180. logger,
  181. )
  182. # Check python environment for conflicts.
  183. await self._pip_check(
  184. path,
  185. self._pip_config.get("pip_check", False),
  186. exec_cwd,
  187. self._pip_env,
  188. logger,
  189. )
  190. except Exception:
  191. logger.info("Delete incomplete virtualenv: %s", path)
  192. shutil.rmtree(path, ignore_errors=True)
  193. logger.exception("Failed to install pip packages.")
  194. raise
  195. def __await__(self):
  196. return self._run().__await__()
  197. class PipPlugin(RuntimeEnvPlugin):
  198. name = "pip"
  199. def __init__(self, resources_dir: str):
  200. self._pip_resources_dir = os.path.join(resources_dir, "pip")
  201. self._creating_task = {}
  202. # Maps a URI to a lock that is used to prevent multiple concurrent
  203. # installs of the same virtualenv, see #24513
  204. self._create_locks: Dict[str, asyncio.Lock] = {}
  205. # Key: created hashes. Value: size of the pip dir.
  206. self._created_hash_bytes: Dict[str, int] = {}
  207. try_to_create_directory(self._pip_resources_dir)
  208. def _get_path_from_hash(self, hash_val: str) -> str:
  209. """Generate a path from the hash of a pip spec.
  210. Example output:
  211. /tmp/ray/session_2021-11-03_16-33-59_356303_41018/runtime_resources
  212. /pip/ray-9a7972c3a75f55e976e620484f58410c920db091
  213. """
  214. return os.path.join(self._pip_resources_dir, hash_val)
  215. def get_uris(self, runtime_env: "RuntimeEnv") -> List[str]: # noqa: F821
  216. """Return the pip URI from the RuntimeEnv if it exists, else return []."""
  217. pip_uri = runtime_env.pip_uri()
  218. if pip_uri:
  219. return [pip_uri]
  220. return []
  221. def delete_uri(
  222. self, uri: str, logger: Optional[logging.Logger] = default_logger
  223. ) -> int:
  224. """Delete URI and return the number of bytes deleted."""
  225. logger.info("Got request to delete pip URI %s", uri)
  226. protocol, hash_val = parse_uri(uri)
  227. if protocol != Protocol.PIP:
  228. raise ValueError(
  229. "PipPlugin can only delete URIs with protocol "
  230. f"pip. Received protocol {protocol}, URI {uri}"
  231. )
  232. # Cancel running create task.
  233. task = self._creating_task.pop(hash_val, None)
  234. if task is not None:
  235. task.cancel()
  236. del self._created_hash_bytes[hash_val]
  237. pip_env_path = self._get_path_from_hash(hash_val)
  238. local_dir_size = get_directory_size_bytes(pip_env_path)
  239. del self._create_locks[uri]
  240. try:
  241. shutil.rmtree(pip_env_path)
  242. except OSError as e:
  243. logger.warning(f"Error when deleting pip env {pip_env_path}: {str(e)}")
  244. return 0
  245. return local_dir_size
  246. async def create(
  247. self,
  248. uri: str,
  249. runtime_env: "RuntimeEnv", # noqa: F821
  250. context: "RuntimeEnvContext", # noqa: F821
  251. logger: Optional[logging.Logger] = default_logger,
  252. ) -> int:
  253. if not runtime_env.has_pip():
  254. return 0
  255. protocol, hash_val = parse_uri(uri)
  256. target_dir = self._get_path_from_hash(hash_val)
  257. async def _create_for_hash():
  258. await PipProcessor(
  259. target_dir,
  260. runtime_env,
  261. logger,
  262. )
  263. loop = get_running_loop()
  264. return await loop.run_in_executor(
  265. None, get_directory_size_bytes, target_dir
  266. )
  267. if uri not in self._create_locks:
  268. # async lock to prevent the same virtualenv being concurrently installed
  269. self._create_locks[uri] = asyncio.Lock()
  270. async with self._create_locks[uri]:
  271. if hash_val in self._created_hash_bytes:
  272. return self._created_hash_bytes[hash_val]
  273. self._creating_task[hash_val] = task = create_task(_create_for_hash())
  274. task.add_done_callback(lambda _: self._creating_task.pop(hash_val, None))
  275. pip_dir_bytes = await task
  276. self._created_hash_bytes[hash_val] = pip_dir_bytes
  277. return pip_dir_bytes
  278. def modify_context(
  279. self,
  280. uris: List[str],
  281. runtime_env: "RuntimeEnv", # noqa: F821
  282. context: "RuntimeEnvContext", # noqa: F821
  283. logger: logging.Logger = default_logger,
  284. ):
  285. if not runtime_env.has_pip():
  286. return
  287. # PipPlugin only uses a single URI.
  288. uri = uris[0]
  289. # Update py_executable.
  290. protocol, hash_val = parse_uri(uri)
  291. target_dir = self._get_path_from_hash(hash_val)
  292. virtualenv_python = virtualenv_utils.get_virtualenv_python(target_dir)
  293. if not os.path.exists(virtualenv_python):
  294. raise ValueError(
  295. f"Local directory {target_dir} for URI {uri} does "
  296. "not exist on the cluster. Something may have gone wrong while "
  297. "installing the runtime_env `pip` packages."
  298. )
  299. context.py_executable = virtualenv_python
  300. context.command_prefix += virtualenv_utils.get_virtualenv_activate_command(
  301. target_dir
  302. )