protocol.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312
import enum
import os
import shutil
from urllib.parse import urlparse
  4. RAY_RUNTIME_ENV_HTTP_USER_AGENT_ENV_VAR = "RAY_RUNTIME_ENV_HTTP_USER_AGENT"
  5. RAY_RUNTIME_ENV_BEARER_TOKEN_ENV_VAR = "RAY_RUNTIME_ENV_BEARER_TOKEN"
  6. _DEFAULT_HTTP_USER_AGENT = "ray-runtime-env-curl/1.0"
  7. class ProtocolsProvider:
  8. _MISSING_DEPENDENCIES_WARNING = (
  9. "Note that these must be preinstalled "
  10. "on all nodes in the Ray cluster; it is not "
  11. "sufficient to install them in the runtime_env."
  12. )
  13. @classmethod
  14. def get_protocols(cls):
  15. return {
  16. # For packages dynamically uploaded and managed by the GCS.
  17. "gcs",
  18. # For conda environments installed locally on each node.
  19. "conda",
  20. # For pip environments installed locally on each node.
  21. "pip",
  22. # For uv environments install locally on each node.
  23. "uv",
  24. # Remote https path, assumes everything packed in one zip file.
  25. "https",
  26. # Remote s3 path, assumes everything packed in one zip file.
  27. "s3",
  28. # Remote google storage path, assumes everything packed in one zip file.
  29. "gs",
  30. # Remote azure blob storage path, assumes everything packed in one zip file.
  31. "azure",
  32. # Remote Azure Blob File System Secure path, assumes everything packed in one zip file.
  33. "abfss",
  34. # File storage path, assumes everything packed in one zip file.
  35. "file",
  36. }
  37. @classmethod
  38. def get_remote_protocols(cls):
  39. return {"https", "s3", "gs", "azure", "abfss", "file"}
  40. @classmethod
  41. def _handle_s3_protocol(cls):
  42. """Set up S3 protocol handling.
  43. Returns:
  44. tuple: (open_file function, transport_params)
  45. Raises:
  46. ImportError: If required dependencies are not installed.
  47. """
  48. try:
  49. import boto3
  50. from smart_open import open as open_file
  51. except ImportError:
  52. raise ImportError(
  53. "You must `pip install smart_open[s3]` "
  54. "to fetch URIs in s3 bucket. " + cls._MISSING_DEPENDENCIES_WARNING
  55. )
  56. # Create S3 client, falling back to unsigned for public buckets
  57. session = boto3.Session()
  58. # session.get_credentials() will return None if no credentials can be found.
  59. if session.get_credentials():
  60. # If credentials are found, use a standard signed client.
  61. s3_client = session.client("s3")
  62. else:
  63. # No credentials found, fall back to an unsigned client for public buckets.
  64. from botocore import UNSIGNED
  65. from botocore.config import Config
  66. s3_client = boto3.client("s3", config=Config(signature_version=UNSIGNED))
  67. transport_params = {"client": s3_client}
  68. return open_file, transport_params
  69. @classmethod
  70. def _handle_gs_protocol(cls):
  71. """Set up Google Cloud Storage protocol handling.
  72. Returns:
  73. tuple: (open_file function, transport_params)
  74. Raises:
  75. ImportError: If required dependencies are not installed.
  76. """
  77. try:
  78. from google.cloud import storage # noqa: F401
  79. from smart_open import open as open_file
  80. except ImportError:
  81. raise ImportError(
  82. "You must `pip install smart_open[gcs]` "
  83. "to fetch URIs in Google Cloud Storage bucket."
  84. + cls._MISSING_DEPENDENCIES_WARNING
  85. )
  86. return open_file, None
  87. @classmethod
  88. def _handle_azure_protocol(cls):
  89. """Set up Azure blob storage protocol handling.
  90. Returns:
  91. tuple: (open_file function, transport_params)
  92. Raises:
  93. ImportError: If required dependencies are not installed.
  94. ValueError: If required environment variables are not set.
  95. """
  96. try:
  97. from azure.identity import DefaultAzureCredential
  98. from azure.storage.blob import BlobServiceClient # noqa: F401
  99. from smart_open import open as open_file
  100. except ImportError:
  101. raise ImportError(
  102. "You must `pip install azure-storage-blob azure-identity smart_open[azure]` "
  103. "to fetch URIs in Azure Blob Storage. "
  104. + cls._MISSING_DEPENDENCIES_WARNING
  105. )
  106. # Define authentication variable
  107. azure_storage_account_name = os.getenv("AZURE_STORAGE_ACCOUNT")
  108. if not azure_storage_account_name:
  109. raise ValueError(
  110. "Azure Blob Storage authentication requires "
  111. "AZURE_STORAGE_ACCOUNT environment variable to be set."
  112. )
  113. account_url = f"https://{azure_storage_account_name}.blob.core.windows.net/"
  114. transport_params = {
  115. "client": BlobServiceClient(
  116. account_url=account_url, credential=DefaultAzureCredential()
  117. )
  118. }
  119. return open_file, transport_params
  120. @classmethod
  121. def _handle_abfss_protocol(cls):
  122. """Set up Azure Blob File System Secure (ABFSS) protocol handling.
  123. Returns:
  124. tuple: (open_file function, transport_params)
  125. Raises:
  126. ImportError: If required dependencies are not installed.
  127. ValueError: If the ABFSS URI format is invalid.
  128. """
  129. try:
  130. import adlfs
  131. from azure.identity import DefaultAzureCredential
  132. except ImportError:
  133. raise ImportError(
  134. "You must `pip install adlfs azure-identity` "
  135. "to fetch URIs in Azure Blob File System Secure. "
  136. + cls._MISSING_DEPENDENCIES_WARNING
  137. )
  138. def open_file(uri, mode, *, transport_params=None):
  139. # Parse and validate the ABFSS URI
  140. parsed = urlparse(uri)
  141. # Validate ABFSS URI format: abfss://container@account.dfs.core.windows.net/path
  142. if not parsed.netloc or "@" not in parsed.netloc:
  143. raise ValueError(
  144. f"Invalid ABFSS URI format - missing container@account: {uri}"
  145. )
  146. container_part, hostname_part = parsed.netloc.split("@", 1)
  147. # Validate container name (must be non-empty)
  148. if not container_part:
  149. raise ValueError(
  150. f"Invalid ABFSS URI format - empty container name: {uri}"
  151. )
  152. # Validate hostname format
  153. if not hostname_part or not hostname_part.endswith(".dfs.core.windows.net"):
  154. raise ValueError(
  155. f"Invalid ABFSS URI format - invalid hostname (must end with .dfs.core.windows.net): {uri}"
  156. )
  157. # Extract and validate account name
  158. azure_storage_account_name = hostname_part.split(".")[0]
  159. if not azure_storage_account_name:
  160. raise ValueError(
  161. f"Invalid ABFSS URI format - empty account name: {uri}"
  162. )
  163. # Handle ABFSS URI with adlfs
  164. filesystem = adlfs.AzureBlobFileSystem(
  165. account_name=azure_storage_account_name,
  166. credential=DefaultAzureCredential(),
  167. )
  168. return filesystem.open(uri, mode)
  169. return open_file, None
  170. @classmethod
  171. def _http_headers(cls) -> dict:
  172. headers = {
  173. "User-Agent": os.environ.get(
  174. RAY_RUNTIME_ENV_HTTP_USER_AGENT_ENV_VAR, _DEFAULT_HTTP_USER_AGENT
  175. ),
  176. "Accept": "*/*",
  177. }
  178. bearer_token = os.environ.get(RAY_RUNTIME_ENV_BEARER_TOKEN_ENV_VAR)
  179. if bearer_token:
  180. headers["Authorization"] = f"Bearer {bearer_token}"
  181. return headers
  182. @classmethod
  183. def _handle_https_protocol(cls):
  184. """Set up HTTPS protocol handling with curl-like headers."""
  185. try:
  186. from smart_open import open as smart_open_open
  187. except ImportError:
  188. raise ImportError(
  189. "You must `pip install smart_open` to fetch HTTPS URIs. "
  190. + cls._MISSING_DEPENDENCIES_WARNING
  191. )
  192. def open_file(uri, mode, *, transport_params=None):
  193. params = {
  194. "headers": cls._http_headers(),
  195. "timeout": 60,
  196. }
  197. if transport_params:
  198. params.update(transport_params)
  199. return smart_open_open(uri, mode, transport_params=params)
  200. return open_file, None
  201. @classmethod
  202. def download_remote_uri(cls, protocol: str, source_uri: str, dest_file: str):
  203. """Download file from remote URI to destination file.
  204. Args:
  205. protocol: The protocol to use for downloading (e.g., 's3', 'https').
  206. source_uri: The source URI to download from.
  207. dest_file: The destination file path to save to.
  208. Raises:
  209. ImportError: If required dependencies for the protocol are not installed.
  210. """
  211. assert protocol in cls.get_remote_protocols()
  212. tp = None
  213. open_file = None
  214. if protocol == "file":
  215. source_uri = source_uri[len("file://") :]
  216. def open_file(uri, mode, *, transport_params=None):
  217. return open(uri, mode)
  218. elif protocol == "https":
  219. open_file, tp = cls._handle_https_protocol()
  220. elif protocol == "s3":
  221. open_file, tp = cls._handle_s3_protocol()
  222. elif protocol == "gs":
  223. open_file, tp = cls._handle_gs_protocol()
  224. elif protocol == "azure":
  225. open_file, tp = cls._handle_azure_protocol()
  226. elif protocol == "abfss":
  227. open_file, tp = cls._handle_abfss_protocol()
  228. else:
  229. try:
  230. from smart_open import open as open_file
  231. except ImportError:
  232. raise ImportError(
  233. "You must `pip install smart_open` "
  234. f"to fetch {protocol.upper()} URIs. "
  235. + cls._MISSING_DEPENDENCIES_WARNING
  236. )
  237. with open_file(source_uri, "rb", transport_params=tp) as fin:
  238. with open(dest_file, "wb") as fout:
  239. fout.write(fin.read())
  240. Protocol = enum.Enum(
  241. "Protocol",
  242. {protocol.upper(): protocol for protocol in ProtocolsProvider.get_protocols()},
  243. )
  244. @classmethod
  245. def _remote_protocols(cls):
  246. # Returns a list of protocols that support remote storage
  247. # These protocols should only be used with paths that end in ".zip" or ".whl"
  248. return [
  249. cls[protocol.upper()] for protocol in ProtocolsProvider.get_remote_protocols()
  250. ]
  251. Protocol.remote_protocols = _remote_protocols
  252. def _download_remote_uri(self, source_uri, dest_file):
  253. return ProtocolsProvider.download_remote_uri(self.value, source_uri, dest_file)
  254. Protocol.download_remote_uri = _download_remote_uri