import asyncio
import hashlib
import logging
import os
import shutil
import sys
import time
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Callable, List, Optional, Tuple
from urllib.parse import urlparse
from zipfile import ZipFile

from filelock import FileLock

from ray._private.path_utils import is_path
from ray._private.ray_constants import (
    GRPC_CPP_MAX_MESSAGE_SIZE,
    RAY_RUNTIME_ENV_URI_PIN_EXPIRATION_S_DEFAULT,
    RAY_RUNTIME_ENV_URI_PIN_EXPIRATION_S_ENV_VAR,
)
from ray._private.runtime_env.conda_utils import exec_cmd_stream_to_logger
from ray._private.runtime_env.protocol import Protocol
from ray._private.thirdparty.pathspec import PathSpec
from ray._raylet import GcsClient
from ray.experimental.internal_kv import (
    _internal_kv_exists,
    _internal_kv_put,
    _pin_runtime_env_uri,
)
from ray.util.annotations import DeveloperAPI

default_logger = logging.getLogger(__name__)

# If an individual file is beyond this size, print a warning.
FILE_SIZE_WARNING = 10 * 1024 * 1024  # 10MiB
# The size is bounded by the max gRPC message size.
# Keep in sync with max_grpc_message_size in ray_config_def.h.
GCS_STORAGE_MAX_SIZE = int(
    os.environ.get("RAY_max_grpc_message_size", GRPC_CPP_MAX_MESSAGE_SIZE)
)
RAY_PKG_PREFIX = "_ray_pkg_"

RAY_RUNTIME_ENV_FAIL_UPLOAD_FOR_TESTING_ENV_VAR = (
    "RAY_RUNTIME_ENV_FAIL_UPLOAD_FOR_TESTING"
)
RAY_RUNTIME_ENV_FAIL_DOWNLOAD_FOR_TESTING_ENV_VAR = (
    "RAY_RUNTIME_ENV_FAIL_DOWNLOAD_FOR_TESTING"
)

# The name of the hidden top-level directory that appears when files are
# zipped on MacOS.
MAC_OS_ZIP_HIDDEN_DIR_NAME = "__MACOSX"


def _mib_string(num_bytes: float) -> str:
    """Format a byte count as a string in MiB with two decimals."""
    size_mib = float(num_bytes / 1024**2)
    return f"{size_mib:.2f}MiB"


def _to_extended_length_path(path: str) -> str:
    r"""Convert paths to extended-length format on Windows if needed.

    Paths on other platforms are returned unchanged.

    Extended-length paths (\\?\) support paths up to 32,767 characters on
    Windows instead of 260.

    Extended-length paths must be normalized (i.e., no "." or ".." components)
    so this function normalizes the path before applying the prefix.

    Args:
        path: The path to convert.

    Returns:
        The path with extended-length prefixed path on Windows,
        unchanged on other platforms.
    """
    # Ensure we always work with strings, not Path objects
    path = str(path)

    if sys.platform != "win32":
        return path

    # Convert to absolute path and fully normalize to remove any . or .. components
    # This is critical because extended-length paths disable Windows path normalization
    abs_path = os.path.normpath(os.path.abspath(path))

    # Extended-length path prefix
    extended_prefix = "\\\\?\\"

    # Already in extended format
    if abs_path.startswith(extended_prefix):
        return abs_path

    # UNC paths need special handling: \\server\share -> \\?\UNC\server\share
    if abs_path.startswith("\\\\"):
        return extended_prefix + "UNC" + abs_path[1:]

    # Local paths: C:\path -> \\?\C:\path
    return extended_prefix + abs_path


class _AsyncFileLock:
    """Asyncio version used to prevent blocking event loop."""

    def __init__(self, lock_file: str):
        self.file = FileLock(lock_file)

    async def __aenter__(self):
        # Poll with a zero timeout so the event loop is never blocked while
        # another holder owns the lock.
        while True:
            try:
                self.file.acquire(timeout=0)
                return
            except TimeoutError:
                await asyncio.sleep(0.1)

    async def __aexit__(self, exc_type, exc, tb):
        self.file.release()


def _xor_bytes(left: bytes, right: bytes) -> bytes:
    """XOR two byte strings; the result is as long as the shorter one.

    If either side is empty/None, the other side is returned unchanged.
    """
    if left and right:
        return bytes(a ^ b for (a, b) in zip(left, right))
    return left or right


def _dir_travel(
    path: Path,
    excludes: List[Callable],
    handler: Callable,
    include_gitignore: bool,
    logger: Optional[logging.Logger] = default_logger,
):
    """Travels the path recursively, calling the handler on each subpath.

    Respects excludes, which will be called to check if this path is skipped.
    """
    # Ignore files found at this level apply to this path and everything
    # below it, so they are pushed here and popped again on the way out.
    new_excludes = get_excludes_from_ignore_files(
        path, include_gitignore=include_gitignore, logger=logger
    )
    excludes.extend(new_excludes)

    skip = any(e(path) for e in excludes)
    if not skip:
        try:
            handler(path)
        except Exception:
            logger.error(f"Issue with path: {path}")
            raise
        if path.is_dir():
            for sub_path in path.iterdir():
                _dir_travel(
                    sub_path,
                    excludes,
                    handler,
                    include_gitignore=include_gitignore,
                    logger=logger,
                )

    for _ in range(len(new_excludes)):
        excludes.pop()


def _hash_file_content_or_directory_name(
    filepath: Path,
    relative_path: Path,
    logger: Optional[logging.Logger] = default_logger,
) -> bytes:
    """Helper function to create hash of a single file or directory.

    This function hashes the path of the file or directory,
    and if it's a file, then it hashes its content too.
    """
    BUF_SIZE = 4096 * 1024

    sha1 = hashlib.sha1()
    sha1.update(str(filepath.relative_to(relative_path)).encode())
    if not filepath.is_dir():
        try:
            f = filepath.open("rb")
        except Exception as e:
            logger.debug(
                f"Skipping contents of file {filepath} when calculating package hash "
                f"because the file couldn't be opened: {e}"
            )
        else:
            # Read in fixed-size chunks so large files are never fully in memory.
            with f:
                for data in iter(lambda: f.read(BUF_SIZE), b""):
                    sha1.update(data)

    return sha1.digest()


def _hash_file(
    filepath: Path,
    relative_path: Path,
    logger: Optional[logging.Logger] = default_logger,
) -> bytes:
    """Helper function to create hash of a single file.

    It hashes the path of the file and its content to create a hash value.
    """
    file_hash = _hash_file_content_or_directory_name(
        filepath, relative_path, logger=logger
    )
    # XOR against an 8-byte seed; the result is truncated to 8 bytes.
    return _xor_bytes(file_hash, b"0" * 8)


def _hash_directory(
    root: Path,
    relative_path: Path,
    excludes: Optional[Callable],
    include_gitignore: bool,
    logger: Optional[logging.Logger] = default_logger,
) -> bytes:
    """Helper function to create hash of a directory.

    It'll go through all the files in the directory and xor
    hash(file_name, file_content) to create a hash value.
    """
    hash_val = b"0" * 8

    def handler(path: Path):
        file_hash = _hash_file_content_or_directory_name(
            path, relative_path, logger=logger
        )
        nonlocal hash_val
        hash_val = _xor_bytes(hash_val, file_hash)

    excludes = [] if excludes is None else [excludes]
    _dir_travel(
        root, excludes, handler, include_gitignore=include_gitignore, logger=logger
    )
    return hash_val


def parse_path(pkg_path: str) -> None:
    """Parse the path to check it is well-formed and exists.

    Raises:
        ValueError: If the path cannot be resolved.
    """
    path = Path(pkg_path)
    try:
        path.resolve(strict=True)
    except OSError as e:
        raise ValueError(f"{path} is not a valid path.") from e


def parse_uri(pkg_uri: str) -> Tuple[Protocol, str]:
    """
    Parse package uri into protocol and package name based on its format.
    Note that the output of this function is not for handling actual IO, it's
    only for setting up local directory folders by using package name as path.

    >>> parse_uri("https://test.com/file.zip")[1]
    'https_test_com_file.zip'

    >>> parse_uri("https://test.com/file.whl")[1]
    'file.whl'

    Raises:
        ValueError: If ``pkg_uri`` is a local path or uses an unknown protocol.
    """
    if is_path(pkg_uri):
        raise ValueError(f"Expected URI but received path {pkg_uri}")

    uri = urlparse(pkg_uri)
    try:
        protocol = Protocol(uri.scheme)
    except ValueError as e:
        raise ValueError(
            f'Invalid protocol for runtime_env URI "{pkg_uri}". '
            f"Supported protocols: {Protocol._member_names_}. Original error: {e}"
        ) from e

    if protocol in Protocol.remote_protocols():
        if uri.path.endswith(".whl"):
            # Don't modify the .whl filename. See
            # https://peps.python.org/pep-0427/#file-name-convention
            # for more information.
            package_name = uri.path.split("/")[-1]
        else:
            package_name = f"{protocol.value}_{uri.netloc}{uri.path}"

            disallowed_chars = ["/", ":", "@", "+", " ", "(", ")"]
            for disallowed_char in disallowed_chars:
                package_name = package_name.replace(disallowed_char, "_")

            # Remove all periods except the last, which is part of the
            # file extension
            package_name = package_name.replace(".", "_", package_name.count(".") - 1)
    else:
        package_name = uri.netloc

    return (protocol, package_name)


def is_zip_uri(uri: str) -> bool:
    """Return True if ``uri`` parses as a URI whose package name ends in .zip."""
    try:
        _, path = parse_uri(uri)
    except ValueError:
        return False

    return Path(path).suffix == ".zip"


def is_whl_uri(uri: str) -> bool:
    """Return True if ``uri`` parses as a URI whose package name ends in .whl."""
    try:
        _, path = parse_uri(uri)
    except ValueError:
        return False

    return Path(path).suffix == ".whl"


def is_jar_uri(uri: str) -> bool:
    """Return True if ``uri`` parses as a URI whose package name ends in .jar."""
    try:
        _, path = parse_uri(uri)
    except ValueError:
        return False

    return Path(path).suffix == ".jar"


def _get_excludes(path: Path, excludes: List[str]) -> Callable:
    """Build a matcher for gitwildmatch ``excludes`` patterns rooted at ``path``."""
    path = path.absolute()
    pathspec = PathSpec.from_lines("gitwildmatch", excludes)

    def match(p: Path):
        path_str = str(p.absolute().relative_to(path))
        return pathspec.match_file(path_str)

    return match


def _get_ignore_file(path: Path, ignore_file: str) -> Optional[Callable]:
    """Returns a function that returns True if the path should be excluded.

    Returns None if there is no ignore_file in the path.

    Args:
        path: The path to the directory to check for an ignore file.
        ignore_file: The name of the ignore file.

    Returns:
        A function that returns True if the path should be excluded.
    """
    path = path.absolute()
    ignore_file_path = path / ignore_file
    if not ignore_file_path.is_file():
        return None

    with ignore_file_path.open("r") as f:
        pathspec = PathSpec.from_lines("gitwildmatch", f.readlines())

    def match(p: Path):
        path_str = str(p.absolute().relative_to(path))
        return pathspec.match_file(path_str)

    return match


def get_excludes_from_ignore_files(
    path: Path,
    include_gitignore: bool,
    logger: Optional[logging.Logger] = default_logger,
) -> List[Callable]:
    """Get exclusion functions from .gitignore and .rayignore files in the current path.

    Args:
        path: The path to check for ignore files.
        include_gitignore: Whether to respect .gitignore files.
        logger: Logger to use.

    Returns:
        List[Callable]: List of exclusion functions. Each function takes a Path
            and returns True if the path should be excluded based on the ignore
            patterns in the respective ignore file.
    """
    ignore_files = []

    to_ignore: List[Callable] = []
    if include_gitignore:
        g = _get_ignore_file(path, ignore_file=".gitignore")
        if g is not None:
            to_ignore.append(g)
            ignore_files.append(path / ".gitignore")

    r = _get_ignore_file(path, ignore_file=".rayignore")
    if r is not None:
        to_ignore.append(r)
        ignore_files.append(path / ".rayignore")

    if ignore_files:
        logger.info(f"Ignoring upload to cluster for these files: {ignore_files}")

    return to_ignore


def pin_runtime_env_uri(uri: str, *, expiration_s: Optional[int] = None) -> None:
    """Pin a reference to a runtime_env URI in the GCS on a timeout.

    This is used to avoid premature eviction in edge conditions for job
    reference counting. See https://github.com/ray-project/ray/pull/24719.

    Packages are uploaded to GCS in order to be downloaded by a runtime env plugin
    (e.g. working_dir, py_modules) after the job starts.

    This function adds a temporary reference to the package in the GCS to prevent
    it from being deleted before the job starts. (See #23423 for the bug where
    this happened.)

    If this reference didn't have an expiration, then if the script exited
    (e.g. via Ctrl-C) before the job started, the reference would never be
    removed, so the package would never be deleted.

    Raises:
        ValueError: If ``expiration_s`` is not an int or is negative.
    """

    if expiration_s is None:
        expiration_s = int(
            os.environ.get(
                RAY_RUNTIME_ENV_URI_PIN_EXPIRATION_S_ENV_VAR,
                RAY_RUNTIME_ENV_URI_PIN_EXPIRATION_S_DEFAULT,
            )
        )
    elif not isinstance(expiration_s, int):
        raise ValueError(f"expiration_s must be an int, got {type(expiration_s)}.")

    if expiration_s < 0:
        raise ValueError(f"expiration_s must be >= 0, got {expiration_s}.")
    elif expiration_s > 0:
        # An expiration of exactly 0 means "do not pin".
        _pin_runtime_env_uri(uri, expiration_s=expiration_s)


def _store_package_in_gcs(
    pkg_uri: str,
    data: bytes,
    logger: Optional[logging.Logger] = default_logger,
) -> int:
    """Stores package data in the Global Control Store (GCS).

    Args:
        pkg_uri: The GCS key to store the data in.
        data: The serialized package's bytes to store in the GCS.
        logger (Optional[logging.Logger]): The logger used by this function.

    Return:
        int: Size of data

    Raises:
        RuntimeError: If the upload to the GCS fails.
        ValueError: If the data's size exceeds GCS_STORAGE_MAX_SIZE.
    """

    file_size = len(data)
    size_str = _mib_string(file_size)
    if file_size >= GCS_STORAGE_MAX_SIZE:
        raise ValueError(
            f"Package size ({size_str}) exceeds the maximum size of "
            f"{_mib_string(GCS_STORAGE_MAX_SIZE)}. To exclude large files, "
            "add them to '.gitignore' or '.rayignore' files, or "
            "use the 'excludes' option in the runtime_env, or provide a "
            "URI of a remote zip file. For more information, refer to "
            "https://docs.ray.io/en/latest/ray-core/handling-dependencies.html#api-reference."  # noqa
        )

    logger.info(f"Pushing file package '{pkg_uri}' ({size_str}) to Ray cluster...")
    try:
        if os.environ.get(RAY_RUNTIME_ENV_FAIL_UPLOAD_FOR_TESTING_ENV_VAR):
            raise RuntimeError(
                "Simulating failure to upload package for testing purposes."
            )
        _internal_kv_put(pkg_uri, data)
    except Exception as e:
        raise RuntimeError(
            "Failed to store package in the GCS.\n"
            f"  - GCS URI: {pkg_uri}\n"
            f"  - Package data ({size_str}): {data[:15]}...\n"
        ) from e
    logger.info(f"Successfully pushed file package '{pkg_uri}'.")
    return file_size


def _get_local_path(base_directory: str, pkg_uri: str) -> str:
    """Return the local path under ``base_directory`` for the given package URI."""
    _, pkg_name = parse_uri(pkg_uri)
    return os.path.join(base_directory, pkg_name)


def _zip_files(
    path_str: str,
    excludes: List[str],
    output_path: str,
    include_gitignore: bool,
    include_parent_dir: bool = False,
    logger: Optional[logging.Logger] = default_logger,
) -> None:
    """Zip the target file or directory and write it to the output_path.

    Args:
        path_str: The file or directory to zip.
        excludes (List(str)): The directories or file to be excluded.
        output_path: The output path for the zip file.
        include_gitignore: Whether to respect .gitignore files.
        include_parent_dir: If true, includes the top-level directory as a
            directory inside the zip file.
        logger: The logger to use.
    """
    pkg_file = Path(output_path).absolute()
    # Use extended-length paths on Windows to avoid MAX_PATH limitations
    extended_pkg_file = _to_extended_length_path(str(pkg_file))
    with ZipFile(extended_pkg_file, "w", strict_timestamps=False) as zip_handler:
        # Put all files in the directory into the zip file.
        file_path = Path(path_str).absolute()
        dir_path = file_path
        if file_path.is_file():
            dir_path = file_path.parent

        def handler(path: Path):
            # Pack this path if it's an empty directory or it's a file.
            is_empty_dir = path.is_dir() and next(path.iterdir(), None) is None
            if is_empty_dir or path.is_file():
                file_size = path.stat().st_size
                if file_size >= FILE_SIZE_WARNING:
                    logger.warning(
                        f"File {path} is very large "
                        f"({_mib_string(file_size)}). Consider adding this "
                        "file to the 'excludes' list to skip uploading it: "
                        "`ray.init(..., "
                        f"runtime_env={{'excludes': ['{path}']}})`"
                    )
                to_path = path.relative_to(dir_path)
                if include_parent_dir:
                    to_path = dir_path.name / to_path
                zip_handler.write(path, to_path)

        excludes = [_get_excludes(file_path, excludes)]
        _dir_travel(
            file_path,
            excludes,
            handler,
            include_gitignore=include_gitignore,
            logger=logger,
        )


def package_exists(pkg_uri: str) -> bool:
    """Check whether the package with given URI exists or not.

    Args:
        pkg_uri: The uri of the package

    Return:
        True for package existing and False for not.

    Raises:
        NotImplementedError: If the protocol of the URI is not GCS.
    """
    protocol, pkg_name = parse_uri(pkg_uri)
    if protocol == Protocol.GCS:
        return _internal_kv_exists(pkg_uri)
    else:
        raise NotImplementedError(f"Protocol {protocol} is not supported")


def get_uri_for_package(package: Path) -> str:
    """Get a content-addressable URI from a package's contents."""

    if package.suffix == ".whl":
        # Wheel file names include the Python package name, version
        # and tags, so it is already effectively content-addressed.
        return "{protocol}://{whl_filename}".format(
            protocol=Protocol.GCS.value, whl_filename=package.name
        )
    else:
        hash_val = hashlib.sha1(package.read_bytes()).hexdigest()
        return "{protocol}://{pkg_name}.zip".format(
            protocol=Protocol.GCS.value, pkg_name=RAY_PKG_PREFIX + hash_val
        )


def get_uri_for_file(file: str) -> str:
    """Get a content-addressable URI from a file's content.

    This function generates the name of the package by the file.
    The final package name is _ray_pkg_<HASH_VAL>.zip of this package,
    where HASH_VAL is the hash value of the file.
    For example: _ray_pkg_029f88d5ecc55e1e4d64fc6e388fd103.zip

    Examples:

        >>> get_uri_for_file("/my_file.py")  # doctest: +SKIP
        _ray_pkg_af2734982a741.zip

    Args:
        file: The file.

    Returns:
        URI (str)

    Raises:
        ValueError: If the file doesn't exist.
    """
    filepath = Path(file).absolute()
    if not filepath.exists() or not filepath.is_file():
        raise ValueError(f"File {filepath} must be an existing file")

    hash_val = _hash_file(filepath, filepath.parent)

    return "{protocol}://{pkg_name}.zip".format(
        protocol=Protocol.GCS.value, pkg_name=RAY_PKG_PREFIX + hash_val.hex()
    )


def get_uri_for_directory(
    directory: str,
    include_gitignore: bool,
    excludes: Optional[List[str]] = None,
) -> str:
    """Get a content-addressable URI from a directory's contents.

    This function generates the name of the package by the directory.
    It'll go through all the files in the directory and hash the contents
    of the files to get the hash value of the package.
    The final package name is _ray_pkg_<HASH_VAL>.zip of this package.
    For example: _ray_pkg_029f88d5ecc55e1e4d64fc6e388fd103.zip

    Examples:

        >>> get_uri_for_directory("/my_directory")  # doctest: +SKIP
        _ray_pkg_af2734982a741.zip

    Args:
        directory: The directory.
        include_gitignore: Whether to respect .gitignore files.
        excludes (list[str]): The dir or files that should be excluded.

    Returns:
        URI (str)

    Raises:
        ValueError: If the directory doesn't exist.
    """
    if excludes is None:
        excludes = []

    directory = Path(directory).absolute()
    if not directory.exists() or not directory.is_dir():
        raise ValueError(f"directory {directory} must be an existing directory")

    hash_val = _hash_directory(
        directory,
        directory,
        _get_excludes(directory, excludes),
        include_gitignore=include_gitignore,
    )

    return "{protocol}://{pkg_name}.zip".format(
        protocol=Protocol.GCS.value, pkg_name=RAY_PKG_PREFIX + hash_val.hex()
    )


def upload_package_to_gcs(pkg_uri: str, pkg_bytes: bytes) -> None:
    """Upload a local package to GCS.

    Args:
        pkg_uri: The URI of the package, e.g. gcs://my_package.zip
        pkg_bytes: The data to be uploaded.

    Raises:
        RuntimeError: If the upload fails.
        ValueError: If the pkg_uri is a remote path or if the data's
            size exceeds GCS_STORAGE_MAX_SIZE.
        NotImplementedError: If the protocol of the URI is not supported.

    """
    protocol, pkg_name = parse_uri(pkg_uri)
    if protocol == Protocol.GCS:
        _store_package_in_gcs(pkg_uri, pkg_bytes)
    elif protocol in Protocol.remote_protocols():
        raise ValueError(
            "upload_package_to_gcs should not be called with a remote path."
        )
    else:
        raise NotImplementedError(f"Protocol {protocol} is not supported")


def create_package(
    module_path: str,
    target_path: Path,
    include_gitignore: bool,
    include_parent_dir: bool = False,
    excludes: Optional[List[str]] = None,
    logger: Optional[logging.Logger] = default_logger,
):
    """Zip ``module_path`` into ``target_path`` unless that file already exists.

    Args:
        module_path: The file or directory to package.
        target_path: Where the zip file is written.
        include_gitignore: Whether to respect .gitignore files.
        include_parent_dir: If true, includes the top-level directory as a
            directory inside the zip file.
        excludes: List specifying files to exclude.
        logger: The logger to use; falls back to the module logger if None.
    """
    if excludes is None:
        excludes = []

    if logger is None:
        logger = default_logger

    if not target_path.exists():
        logger.info(f"Creating a file package for local module '{module_path}'.")
        _zip_files(
            module_path,
            excludes,
            str(target_path),
            include_gitignore=include_gitignore,
            include_parent_dir=include_parent_dir,
            logger=logger,
        )


def upload_package_if_needed(
    pkg_uri: str,
    base_directory: str,
    module_path: str,
    include_gitignore: bool,
    include_parent_dir: bool = False,
    excludes: Optional[List[str]] = None,
    logger: Optional[logging.Logger] = default_logger,
) -> bool:
    """Upload the contents of the directory under the given URI.

    This will first create a temporary zip file under the passed
    base_directory.

    If the package already exists in storage, this is a no-op.

    Args:
        pkg_uri: URI of the package to upload.
        base_directory: Directory where package files are stored.
        module_path: The module to be uploaded, either a single .py file or a directory.
        include_gitignore: Whether to respect .gitignore files.
        include_parent_dir: If true, includes the top-level directory as a
            directory inside the zip file.
        excludes: List specifying files to exclude.
        logger: The logger to use; falls back to the module logger if None.

    Returns:
        True if the package was uploaded, False if it already existed.

    Raises:
        RuntimeError: If the upload fails.
        ValueError: If the pkg_uri is a remote path or if the data's
            size exceeds GCS_STORAGE_MAX_SIZE.
        NotImplementedError: If the protocol of the URI is not supported.
    """
    if excludes is None:
        excludes = []

    if logger is None:
        logger = default_logger

    pin_runtime_env_uri(pkg_uri)

    if package_exists(pkg_uri):
        return False

    package_file = Path(_get_local_path(base_directory, pkg_uri))
    # Make the temporary zip file name unique so that it doesn't conflict with
    # concurrent upload_package_if_needed calls with the same pkg_uri.
    # See https://github.com/ray-project/ray/issues/47471.
    package_file = package_file.with_name(
        f"{time.time_ns()}_{os.getpid()}_{package_file.name}"
    )

    try:
        create_package(
            module_path,
            package_file,
            include_gitignore=include_gitignore,
            include_parent_dir=include_parent_dir,
            excludes=excludes,
            logger=logger,
        )
        package_file_bytes = package_file.read_bytes()
    finally:
        # Remove the local file to avoid accumulating temporary zip files,
        # including when zipping or reading failed part-way through.
        try:
            package_file.unlink()
        except FileNotFoundError:
            pass

    upload_package_to_gcs(pkg_uri, package_file_bytes)

    return True


def get_local_dir_from_uri(uri: str, base_directory: str) -> Path:
    """Return the local directory corresponding to this URI."""
    pkg_file = Path(_get_local_path(base_directory, uri))
    local_dir = pkg_file.with_suffix("")
    return local_dir


@DeveloperAPI
async def download_and_unpack_package(
    pkg_uri: str,
    base_directory: str,
    gcs_client: Optional[GcsClient] = None,
    logger: Optional[logging.Logger] = default_logger,
    overwrite: bool = False,
) -> str:
    """Download the package corresponding to this URI and unpack it if zipped.

    Will be written to a file or directory named {base_directory}/{uri}.
    Returns the path to this file or directory.

    Args:
        pkg_uri: URI of the package to download.
        base_directory: Directory to use as the parent directory of the target
            directory for the unpacked files.
        gcs_client: Client to use for downloading from the GCS.
        logger: The logger to use.
        overwrite: If True, overwrite the existing package.

    Returns:
        Path to the local directory containing the unpacked package files.

    Raises:
        IOError: If the download fails.
        ImportError: If smart_open is not installed and a remote URI is used.
        NotImplementedError: If the protocol of the URI is not supported.
        ValueError: If the GCS client is not provided when downloading from GCS,
                    or if package URI is invalid.

    """
    pkg_file = Path(_get_local_path(base_directory, pkg_uri))
    if pkg_file.suffix == "":
        raise ValueError(
            f"Invalid package URI: {pkg_uri}. "
            "URI must have a file extension and the URI must be valid."
        )

    async with _AsyncFileLock(str(pkg_file) + ".lock"):
        if logger is None:
            logger = default_logger

        logger.debug(f"Fetching package for URI: {pkg_uri}")

        local_dir = get_local_dir_from_uri(pkg_uri, base_directory)
        assert local_dir != pkg_file, "Invalid pkg_file!"

        download_package: bool = True
        if local_dir.exists() and not overwrite:
            download_package = False
            assert local_dir.is_dir(), f"{local_dir} is not a directory"
        elif local_dir.exists():
            logger.info(f"Removing {local_dir} with pkg_file {pkg_file}")
            shutil.rmtree(local_dir)

        if download_package:
            protocol, _ = parse_uri(pkg_uri)
            logger.info(
                f"Downloading package from {pkg_uri} to {pkg_file} "
                f"with protocol {protocol}"
            )
            if protocol == Protocol.GCS:
                if gcs_client is None:
                    raise ValueError(
                        "GCS client must be provided to download from GCS."
                    )

                # Download package from the GCS.
                code = await gcs_client.async_internal_kv_get(
                    pkg_uri.encode(), namespace=None, timeout=None
                )
                if os.environ.get(RAY_RUNTIME_ENV_FAIL_DOWNLOAD_FOR_TESTING_ENV_VAR):
                    code = None
                if code is None:
                    raise IOError(
                        f"Failed to download runtime_env file package {pkg_uri} "
                        "from the GCS to the Ray worker node. The package may "
                        "have prematurely been deleted from the GCS due to a "
                        "long upload time or a problem with Ray. Try setting the "
                        "environment variable "
                        f"{RAY_RUNTIME_ENV_URI_PIN_EXPIRATION_S_ENV_VAR} "
                        " to a value larger than the upload time in seconds "
                        "(the default is "
                        f"{RAY_RUNTIME_ENV_URI_PIN_EXPIRATION_S_DEFAULT}). "
                        "If this fails, try re-running "
                        "after making any change to a file in the file package."
                    )
                pkg_file.write_bytes(code)

                if is_zip_uri(pkg_uri):
                    unzip_package(
                        package_path=pkg_file,
                        target_dir=local_dir,
                        remove_top_level_directory=False,
                        unlink_zip=True,
                        logger=logger,
                    )
                else:
                    return str(pkg_file)
            elif protocol in Protocol.remote_protocols():
                protocol.download_remote_uri(source_uri=pkg_uri, dest_file=pkg_file)

                if pkg_file.suffix in [".zip", ".jar"]:
                    unzip_package(
                        package_path=pkg_file,
                        target_dir=local_dir,
                        remove_top_level_directory=True,
                        unlink_zip=True,
                        logger=logger,
                    )
                elif pkg_file.suffix == ".whl":
                    return str(pkg_file)
                else:
                    # A single string: passing two positional args would turn
                    # the exception message into a tuple.
                    raise NotImplementedError(
                        f"Package format {pkg_file.suffix} is "
                        "not supported for remote protocols"
                    )
            else:
                raise NotImplementedError(f"Protocol {protocol} is not supported")

        return str(local_dir)


def get_top_level_dir_from_compressed_package(package_path: str):
    """
    If compressed package at package_path contains a single top-level
    directory, returns the name of the top-level directory. Otherwise,
    returns None.

    Ignores a second top-level directory if it is named __MACOSX.
    """

    def is_top_level_file(file_name):
        return "/" not in file_name

    def base_dir_name(file_name):
        return file_name.split("/")[0]

    top_level_directory = None

    # Only the archive's name list is needed; close the handle promptly.
    with ZipFile(package_path, "r") as package_zip:
        for file_name in package_zip.namelist():
            if top_level_directory is None:
                # Cache the top_level_directory name when checking
                # the first file in the zipped package
                if is_top_level_file(file_name):
                    return None
                else:
                    # Top-level directory, or non-top-level file or directory
                    dir_name = base_dir_name(file_name)
                    if dir_name == MAC_OS_ZIP_HIDDEN_DIR_NAME:
                        continue
                    top_level_directory = dir_name
            else:
                # Confirm that all other files
                # belong to the same top_level_directory
                if is_top_level_file(file_name) or base_dir_name(file_name) not in [
                    top_level_directory,
                    MAC_OS_ZIP_HIDDEN_DIR_NAME,
                ]:
                    return None

    return top_level_directory


def remove_dir_from_filepaths(base_dir: str, rdir: str):
    """
    base_dir: String path of the directory containing rdir
    rdir: String path of directory relative to base_dir whose contents should
          be moved to its base_dir, its parent directory

    Removes rdir from the filepaths of all files and directories inside it.
    In other words, moves all the files inside rdir to the directory that
    contains rdir. Assumes base_dir's contents and rdir's contents have no
    name conflicts.
    """

    # Move rdir to a temporary directory, so its contents can be moved to
    # base_dir without any name conflicts
    with TemporaryDirectory() as tmp_dir:
        # Apply extended-length path to temp directory to handle long paths
        extended_tmp_dir = _to_extended_length_path(tmp_dir)

        # shutil.move() is used instead of os.rename() in case rdir and tmp_dir
        # are located on separate file systems
        shutil.move(os.path.join(base_dir, rdir), os.path.join(extended_tmp_dir, rdir))

        # Shift children out of rdir and into base_dir
        rdir_children = os.listdir(os.path.join(extended_tmp_dir, rdir))
        for child in rdir_children:
            shutil.move(
                os.path.join(extended_tmp_dir, rdir, child),
                os.path.join(base_dir, child),
            )


def unzip_package(
    package_path: str,
    target_dir: str,
    remove_top_level_directory: bool,
    unlink_zip: bool,
    logger: Optional[logging.Logger] = default_logger,
) -> None:
    """
    Unzip the compressed package contained at package_path to target_dir.

    If remove_top_level_directory is True and the top level consists of a
    a single directory (or possibly also a second hidden directory named
    __MACOSX at the top level arising from macOS's zip command), the function
    will automatically remove the top-level directory and store the contents
    directly in target_dir.

    Otherwise, if remove_top_level_directory is False or if the top level
    consists of multiple files or directories (not counting __MACOS),
    the zip contents will be stored in target_dir.

    Members whose resolved destination falls outside target_dir are skipped
    with a warning.

    Args:
        package_path: String path of the compressed package to unzip.
        target_dir: String path of the directory to store the unzipped contents.
        remove_top_level_directory: Whether to remove the top-level directory
            from the zip contents.
        unlink_zip: Whether to unlink the zip file stored at package_path.
        logger: Optional logger to use for logging.
    """
    if logger is None:
        logger = default_logger

    # Use extended-length paths on Windows to avoid MAX_PATH limitations
    extended_target_dir = _to_extended_length_path(target_dir)

    try:
        os.mkdir(extended_target_dir)
    except FileExistsError:
        logger.info(f"Directory at {target_dir} already exists")

    logger.debug(f"Unpacking {package_path} to {extended_target_dir}")

    # Canonical (absolute, normalized) form of the extraction root.
    # os.path.commonpath() does not collapse ".." components, so both sides of
    # the containment check below must be normalized on every platform.
    safe_root = os.path.normcase(
        os.path.normpath(os.path.abspath(extended_target_dir))
    )

    with ZipFile(str(package_path), "r") as zip_ref:
        # ZipFile.extractall() doesn't support extended paths
        # on Windows, which are needed to handle paths longer than 260
        # characters, so we implement our own extraction logic here.
        for member in zip_ref.namelist():
            # Build the full extraction path with extended-length prefix
            member_path = os.path.join(extended_target_dir, member)
            member_path = _to_extended_length_path(member_path)

            # Ensure the resolved path is within target_dir to prevent
            # path traversal attacks (e.g., ../../../etc/malicious).
            resolved_member = os.path.normcase(
                os.path.normpath(os.path.abspath(member_path))
            )
            try:
                common = os.path.commonpath([safe_root, resolved_member])
            except ValueError:
                # Paths on different drives (Windows)
                logger.warning(f"Skipping path on different drive in zip: {member}")
                continue
            if common != safe_root:
                logger.warning(f"Skipping unsafe path in zip: {member}")
                continue

            logger.debug(f"Extracting {member} to {member_path}")

            # Get ZipInfo for this member to access metadata
            zip_info = zip_ref.getinfo(member)

            # Create directories if this is a directory entry
            if member.endswith("/"):
                os.makedirs(member_path, exist_ok=True)
            else:
                # Ensure parent directory exists
                parent_dir = os.path.dirname(member_path)
                if parent_dir:
                    os.makedirs(parent_dir, exist_ok=True)

                # Extract the file
                with zip_ref.open(member) as source, open(member_path, "wb") as target:
                    shutil.copyfileobj(source, target)

                # Preserve file permissions from the zip archive
                # ZipInfo.external_attr contains Unix file mode in upper 16 bits
                if zip_info.external_attr:
                    # Extract Unix file mode from external_attr
                    mode = zip_info.external_attr >> 16
                    if mode:
                        os.chmod(member_path, mode)

    if remove_top_level_directory:
        top_level_directory = get_top_level_dir_from_compressed_package(package_path)
        if top_level_directory is not None:
            # Remove __MACOSX directory if it exists
            # Use extended path to handle long paths on Windows
            macos_dir = _to_extended_length_path(
                os.path.join(target_dir, MAC_OS_ZIP_HIDDEN_DIR_NAME)
            )
            if os.path.isdir(macos_dir):
                shutil.rmtree(macos_dir)

            # Use extended path for cleanup operations
            remove_dir_from_filepaths(extended_target_dir, top_level_directory)

    if unlink_zip:
        Path(package_path).unlink()


def delete_package(pkg_uri: str, base_directory: str) -> bool:
    """Deletes a specific URI from the local filesystem.

    Args:
        pkg_uri: URI to delete.
        base_directory: Directory under which the package was stored.

    Returns:
        bool: True if the URI was successfully deleted, else False.
    """

    deleted = False
    path = Path(_get_local_path(base_directory, pkg_uri))
    with FileLock(str(path) + ".lock"):
        # The unpacked package lives at the package file path minus its suffix.
        path = path.with_suffix("")
        if path.exists():
            if path.is_dir() and not path.is_symlink():
                shutil.rmtree(str(path))
            else:
                path.unlink()
            deleted = True

    return deleted


async def install_wheel_package(
    wheel_uri: str,
    target_dir: str,
    logger: Optional[logging.Logger] = default_logger,
) -> None:
    """Install packages in the wheel URI, and then delete the local wheel file.

    Args:
        wheel_uri: Local path of the wheel to install; removed afterwards.
        target_dir: Directory passed to pip's ``--target`` option.
        logger: The logger to use; falls back to the module logger if None.

    Raises:
        RuntimeError: If pip exits with a non-zero code. ``target_dir`` is
            removed in that case.
    """
    if logger is None:
        logger = default_logger

    pip_install_cmd = [
        "pip",
        "install",
        wheel_uri,
        f"--target={target_dir}",
    ]

    logger.info("Running py_modules wheel install command: %s", str(pip_install_cmd))
    try:
        # TODO(architkulkarni): Use `await check_output_cmd` or similar.
        exit_code, output = exec_cmd_stream_to_logger(pip_install_cmd, logger)
    finally:
        # Always clean up the downloaded wheel, whether or not pip succeeded.
        wheel_uri_path = Path(wheel_uri)
        if wheel_uri_path.exists():
            if wheel_uri_path.is_dir():
                shutil.rmtree(wheel_uri)
            else:
                Path(wheel_uri).unlink()

    if exit_code != 0:
        if Path(target_dir).exists():
            shutil.rmtree(target_dir)
        raise RuntimeError(
            f"Failed to install py_modules wheel {wheel_uri} "
            f"to {target_dir}:\n{output}"
        )