| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312 |
- import enum
- import os
- from urllib.parse import urlparse
# User-Agent value used for HTTPS downloads when the override variable is unset.
_DEFAULT_HTTP_USER_AGENT = "ray-runtime-env-curl/1.0"

# Name of the environment variable whose value replaces the default User-Agent.
RAY_RUNTIME_ENV_HTTP_USER_AGENT_ENV_VAR = "RAY_RUNTIME_ENV_HTTP_USER_AGENT"

# Name of the environment variable holding a bearer token; when it is set to a
# non-empty value, HTTPS requests carry an "Authorization: Bearer <token>" header.
RAY_RUNTIME_ENV_BEARER_TOKEN_ENV_VAR = "RAY_RUNTIME_ENV_BEARER_TOKEN"
class ProtocolsProvider:
    """Catalogue of supported URI protocols plus the download logic for remote ones.

    Every ``_handle_<protocol>_protocol`` helper returns a pair
    ``(open_file, transport_params)``: ``open_file(uri, mode,
    transport_params=...)`` yields a readable file-like context manager and
    ``transport_params`` is either ``None`` or a dict forwarded to it.
    """

    _MISSING_DEPENDENCIES_WARNING = (
        "Note that these must be preinstalled "
        "on all nodes in the Ray cluster; it is not "
        "sufficient to install them in the runtime_env."
    )

    @classmethod
    def get_protocols(cls):
        """Return the set of all supported protocol names (lowercase strings)."""
        return {
            # For packages dynamically uploaded and managed by the GCS.
            "gcs",
            # For conda environments installed locally on each node.
            "conda",
            # For pip environments installed locally on each node.
            "pip",
            # For uv environments installed locally on each node.
            "uv",
            # Remote https path, assumes everything packed in one zip file.
            "https",
            # Remote s3 path, assumes everything packed in one zip file.
            "s3",
            # Remote google storage path, assumes everything packed in one zip file.
            "gs",
            # Remote azure blob storage path, assumes everything packed in one zip file.
            "azure",
            # Remote Azure Blob File System Secure path, assumes everything packed in one zip file.
            "abfss",
            # File storage path, assumes everything packed in one zip file.
            "file",
        }

    @classmethod
    def get_remote_protocols(cls):
        """Return the subset of protocols that ``download_remote_uri`` accepts."""
        return {"https", "s3", "gs", "azure", "abfss", "file"}

    @classmethod
    def _handle_s3_protocol(cls):
        """Set up S3 protocol handling.

        Returns:
            tuple: (open_file function, transport_params)

        Raises:
            ImportError: If required dependencies are not installed.
        """
        try:
            import boto3
            from smart_open import open as open_file
        except ImportError as err:
            raise ImportError(
                "You must `pip install smart_open[s3]` "
                "to fetch URIs in s3 bucket. " + cls._MISSING_DEPENDENCIES_WARNING
            ) from err

        # Create S3 client, falling back to unsigned for public buckets.
        session = boto3.Session()
        # session.get_credentials() will return None if no credentials can be found.
        if session.get_credentials():
            # If credentials are found, use a standard signed client.
            s3_client = session.client("s3")
        else:
            # No credentials found, fall back to an unsigned client for public buckets.
            from botocore import UNSIGNED
            from botocore.config import Config

            s3_client = boto3.client("s3", config=Config(signature_version=UNSIGNED))

        transport_params = {"client": s3_client}
        return open_file, transport_params

    @classmethod
    def _handle_gs_protocol(cls):
        """Set up Google Cloud Storage protocol handling.

        Returns:
            tuple: (open_file function, transport_params)

        Raises:
            ImportError: If required dependencies are not installed.
        """
        try:
            from google.cloud import storage  # noqa: F401
            from smart_open import open as open_file
        except ImportError as err:
            # The trailing space keeps this sentence separated from the shared
            # warning text that is appended to it.
            raise ImportError(
                "You must `pip install smart_open[gcs]` "
                "to fetch URIs in Google Cloud Storage bucket. "
                + cls._MISSING_DEPENDENCIES_WARNING
            ) from err

        return open_file, None

    @classmethod
    def _handle_azure_protocol(cls):
        """Set up Azure blob storage protocol handling.

        Returns:
            tuple: (open_file function, transport_params)

        Raises:
            ImportError: If required dependencies are not installed.
            ValueError: If required environment variables are not set.
        """
        try:
            from azure.identity import DefaultAzureCredential
            from azure.storage.blob import BlobServiceClient
            from smart_open import open as open_file
        except ImportError as err:
            raise ImportError(
                "You must `pip install azure-storage-blob azure-identity smart_open[azure]` "
                "to fetch URIs in Azure Blob Storage. "
                + cls._MISSING_DEPENDENCIES_WARNING
            ) from err

        # The storage account name comes from the environment; it determines
        # the blob endpoint URL the client talks to.
        azure_storage_account_name = os.getenv("AZURE_STORAGE_ACCOUNT")
        if not azure_storage_account_name:
            raise ValueError(
                "Azure Blob Storage authentication requires "
                "AZURE_STORAGE_ACCOUNT environment variable to be set."
            )

        account_url = f"https://{azure_storage_account_name}.blob.core.windows.net/"
        transport_params = {
            "client": BlobServiceClient(
                account_url=account_url, credential=DefaultAzureCredential()
            )
        }
        return open_file, transport_params

    @classmethod
    def _handle_abfss_protocol(cls):
        """Set up Azure Blob File System Secure (ABFSS) protocol handling.

        Returns:
            tuple: (open_file function, transport_params)

        Raises:
            ImportError: If required dependencies are not installed.
            ValueError: If the ABFSS URI format is invalid (raised by the
                returned ``open_file`` when it is called, not by this method).
        """
        try:
            import adlfs
            from azure.identity import DefaultAzureCredential
        except ImportError as err:
            raise ImportError(
                "You must `pip install adlfs azure-identity` "
                "to fetch URIs in Azure Blob File System Secure. "
                + cls._MISSING_DEPENDENCIES_WARNING
            ) from err

        def open_file(uri, mode, *, transport_params=None):
            # transport_params is accepted for signature parity with the other
            # openers; it is not used here.
            parsed = urlparse(uri)

            # Validate ABFSS URI format: abfss://container@account.dfs.core.windows.net/path
            if not parsed.netloc or "@" not in parsed.netloc:
                raise ValueError(
                    f"Invalid ABFSS URI format - missing container@account: {uri}"
                )

            container_part, hostname_part = parsed.netloc.split("@", 1)

            # Validate container name (must be non-empty)
            if not container_part:
                raise ValueError(
                    f"Invalid ABFSS URI format - empty container name: {uri}"
                )

            # Validate hostname format
            if not hostname_part or not hostname_part.endswith(".dfs.core.windows.net"):
                raise ValueError(
                    f"Invalid ABFSS URI format - invalid hostname (must end with .dfs.core.windows.net): {uri}"
                )

            # Extract and validate account name (first label of the hostname)
            azure_storage_account_name = hostname_part.split(".")[0]
            if not azure_storage_account_name:
                raise ValueError(
                    f"Invalid ABFSS URI format - empty account name: {uri}"
                )

            # Handle ABFSS URI with adlfs
            filesystem = adlfs.AzureBlobFileSystem(
                account_name=azure_storage_account_name,
                credential=DefaultAzureCredential(),
            )
            return filesystem.open(uri, mode)

        return open_file, None

    @classmethod
    def _http_headers(cls) -> dict:
        """Build the request headers used for HTTPS downloads.

        The User-Agent is taken from the environment variable named by
        ``RAY_RUNTIME_ENV_HTTP_USER_AGENT_ENV_VAR`` (falling back to
        ``_DEFAULT_HTTP_USER_AGENT``). An ``Authorization: Bearer`` header is
        added only when the variable named by
        ``RAY_RUNTIME_ENV_BEARER_TOKEN_ENV_VAR`` is set and non-empty.
        """
        headers = {
            "User-Agent": os.environ.get(
                RAY_RUNTIME_ENV_HTTP_USER_AGENT_ENV_VAR, _DEFAULT_HTTP_USER_AGENT
            ),
            "Accept": "*/*",
        }
        bearer_token = os.environ.get(RAY_RUNTIME_ENV_BEARER_TOKEN_ENV_VAR)
        if bearer_token:
            headers["Authorization"] = f"Bearer {bearer_token}"
        return headers

    @classmethod
    def _handle_https_protocol(cls):
        """Set up HTTPS protocol handling with curl-like headers.

        Returns:
            tuple: (open_file function, transport_params)

        Raises:
            ImportError: If smart_open is not installed.
        """
        try:
            from smart_open import open as smart_open_open
        except ImportError as err:
            raise ImportError(
                "You must `pip install smart_open` to fetch HTTPS URIs. "
                + cls._MISSING_DEPENDENCIES_WARNING
            ) from err

        def open_file(uri, mode, *, transport_params=None):
            # Defaults first; caller-supplied transport params override them.
            params = {
                "headers": cls._http_headers(),
                "timeout": 60,
            }
            if transport_params:
                params.update(transport_params)
            return smart_open_open(uri, mode, transport_params=params)

        return open_file, None

    @classmethod
    def download_remote_uri(cls, protocol: str, source_uri: str, dest_file: str):
        """Download file from remote URI to destination file.

        The payload is streamed in chunks, so the whole file is never held
        in memory at once.

        Args:
            protocol: The protocol to use for downloading (e.g., 's3', 'https').
                Must be one of ``get_remote_protocols()``.
            source_uri: The source URI to download from.
            dest_file: The destination file path to save to.

        Raises:
            ImportError: If required dependencies for the protocol are not installed.
        """
        # Local import, in keeping with the function-scope imports used
        # throughout this class.
        import shutil

        assert protocol in cls.get_remote_protocols()

        tp = None
        open_file = None

        if protocol == "file":
            # Strip the scheme so the remainder is a plain filesystem path.
            source_uri = source_uri[len("file://") :]

            def open_file(uri, mode, *, transport_params=None):
                return open(uri, mode)

        elif protocol == "https":
            open_file, tp = cls._handle_https_protocol()
        elif protocol == "s3":
            open_file, tp = cls._handle_s3_protocol()
        elif protocol == "gs":
            open_file, tp = cls._handle_gs_protocol()
        elif protocol == "azure":
            open_file, tp = cls._handle_azure_protocol()
        elif protocol == "abfss":
            open_file, tp = cls._handle_abfss_protocol()
        else:
            # Reached only when get_remote_protocols() (e.g. overridden in a
            # subclass) lists a protocol without a dedicated handler above.
            try:
                from smart_open import open as open_file
            except ImportError as err:
                raise ImportError(
                    "You must `pip install smart_open` "
                    f"to fetch {protocol.upper()} URIs. "
                    + cls._MISSING_DEPENDENCIES_WARNING
                ) from err

        with open_file(source_uri, "rb", transport_params=tp) as fin:
            with open(dest_file, "wb") as fout:
                # Chunked copy instead of fin.read(): bounded memory use.
                shutil.copyfileobj(fin, fout)
# Enum with one member per protocol reported by ProtocolsProvider: the member
# name is the upper-cased protocol string and the member value is the protocol
# string itself (e.g. Protocol.S3.value == "s3").
Protocol = enum.Enum(
    "Protocol",
    names={
        scheme.upper(): scheme for scheme in ProtocolsProvider.get_protocols()
    },
)
@classmethod
def _remote_protocols(cls):
    """Return the enum members for protocols that support remote storage.

    These protocols should only be used with paths that end in ".zip" or ".whl".
    """
    members = []
    for scheme in ProtocolsProvider.get_remote_protocols():
        # Member names are the upper-cased protocol strings.
        members.append(cls[scheme.upper()])
    return members


# Exposed on the enum as Protocol.remote_protocols().
Protocol.remote_protocols = _remote_protocols
def _download_remote_uri(self, source_uri, dest_file):
    """Download ``source_uri`` to ``dest_file`` using this member's protocol.

    Delegates to ProtocolsProvider.download_remote_uri, passing the member's
    value as the protocol name, and returns whatever that call returns.
    """
    scheme = self.value
    return ProtocolsProvider.download_remote_uri(scheme, source_uri, dest_file)


# Exposed on the enum as an instance method of each Protocol member.
Protocol.download_remote_uri = _download_remote_uri
|