| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265 |
- import json
- import logging
- import os
- from abc import ABC
- from typing import Any, Dict, List, Optional, Type
- from ray._common.utils import import_attr
- from ray._private.runtime_env.constants import (
- RAY_RUNTIME_ENV_CLASS_FIELD_NAME,
- RAY_RUNTIME_ENV_PLUGIN_DEFAULT_PRIORITY,
- RAY_RUNTIME_ENV_PLUGIN_MAX_PRIORITY,
- RAY_RUNTIME_ENV_PLUGIN_MIN_PRIORITY,
- RAY_RUNTIME_ENV_PLUGINS_ENV_VAR,
- RAY_RUNTIME_ENV_PRIORITY_FIELD_NAME,
- )
- from ray._private.runtime_env.context import RuntimeEnvContext
- from ray._private.runtime_env.uri_cache import URICache
- from ray.util.annotations import DeveloperAPI
- default_logger = logging.getLogger(__name__)
@DeveloperAPI
class RuntimeEnvPlugin(ABC):
    """Base class from which runtime environment plugins derive.

    Every hook below ships with a do-nothing default, so a concrete plugin
    only overrides the hooks it actually needs.
    """

    # Identifier of the plugin. RuntimeEnvPluginManager rejects plugin classes
    # whose name is falsy, so concrete plugins must override this.
    name: str = None
    # Setup ordering used by RuntimeEnvPluginManager; a smaller number means
    # the plugin is set up earlier.
    priority: int = RAY_RUNTIME_ENV_PLUGIN_DEFAULT_PRIORITY

    @staticmethod
    def validate(runtime_env_dict: dict) -> None:
        """Check the user-provided entry for this plugin.

        Called when a runtime env is being installed. The default accepts
        everything.

        Args:
            runtime_env_dict: The runtime environment dict supplied by the user.

        Raises:
            ValueError: If the entry is not valid.
        """
        return None

    def get_uris(self, runtime_env: "RuntimeEnv") -> List[str]:  # noqa: F821
        """Return the URIs this plugin needs for ``runtime_env``.

        The default implementation reports no URIs.

        Args:
            runtime_env: The RuntimeEnv object.

        Returns:
            A list of URI strings; empty by default.
        """
        return []

    async def create(
        self,
        uri: Optional[str],
        runtime_env,
        context: RuntimeEnvContext,
        logger: logging.Logger,
    ) -> float:
        """Build and install the runtime environment resource.

        Invoked by the runtime env agent at install time. The URI may serve
        as a cache key for the created resource.

        Args:
            uri: A URI that uniquely identifies this resource.
            runtime_env: The RuntimeEnv object.
            context: Auxiliary information provided by Ray.
            logger: Logger for messages emitted while creating the resource.

        Returns:
            float: Disk space consumed by this plugin's installation for the
                environment (e.g. working_dir downloads files onto the
                local node). The default implementation reports 0.
        """
        return 0

    def modify_context(
        self,
        uris: List[str],
        runtime_env: "RuntimeEnv",  # noqa: F821
        context: RuntimeEnvContext,
        logger: logging.Logger,
    ) -> None:
        """Adjust the context to alter how workers start up.

        Typical uses are prepending a "cd <dir>" command to the worker
        startup command or injecting extra environment variables. The
        default implementation leaves the context untouched.

        Args:
            uris: The URIs this resource uses.
            runtime_env: The RuntimeEnv object.
            context: Auxiliary information provided by Ray.
            logger: Logger for messages emitted while modifying the context.
        """
        return None

    def delete_uri(self, uri: str, logger: logging.Logger) -> float:
        """Remove the runtime environment resource identified by ``uri``.

        Args:
            uri: A URI that uniquely identifies this resource.
            logger: Logger for messages emitted during the deletion.

        Returns:
            float: The amount of space freed by the deletion. The default
                implementation reports 0.
        """
        return 0
class PluginSetupContext:
    """Record of everything tracked for one loaded runtime env plugin."""

    def __init__(
        self,
        name: str,
        class_instance: RuntimeEnvPlugin,
        priority: int,
        uri_cache: URICache,
    ):
        """Store the plugin's name, instance, priority and URI cache.

        Args:
            name: The plugin name.
            class_instance: The instantiated plugin object.
            priority: The priority assigned to the plugin.
            uri_cache: The URI cache associated with the plugin.
        """
        self.name, self.class_instance = name, class_instance
        self.priority, self.uri_cache = priority, uri_cache
class RuntimeEnvPluginManager:
    """This manager is used to load plugins in runtime env agent.

    Loaded plugins are stored in ``self.plugins``, keyed by plugin name. Each
    value is a ``PluginSetupContext`` holding the plugin instance, its
    priority and its URI cache.
    """

    def __init__(self):
        """Create the manager and load plugins configured via the environment.

        If the environment variable named by
        ``RAY_RUNTIME_ENV_PLUGINS_ENV_VAR`` is set to a non-empty value, that
        value is parsed as JSON and passed to ``load_plugins``.
        """
        self.plugins: Dict[str, PluginSetupContext] = {}
        plugin_config_str = os.environ.get(RAY_RUNTIME_ENV_PLUGINS_ENV_VAR)
        if plugin_config_str:
            plugin_configs = json.loads(plugin_config_str)
            self.load_plugins(plugin_configs)

    def validate_plugin_class(self, plugin_class: Type[RuntimeEnvPlugin]) -> None:
        """Check that ``plugin_class`` can be registered with this manager.

        Args:
            plugin_class: The candidate plugin class.

        Raises:
            RuntimeError: If ``plugin_class`` is not a subclass of
                ``RuntimeEnvPlugin``, has no name, or its name is already
                registered.
        """
        # issubclass() raises TypeError when handed a non-class (for example
        # a function or module resolved from a bad config entry), so check
        # that first and report the same RuntimeError as other invalid input.
        if not isinstance(plugin_class, type) or not issubclass(
            plugin_class, RuntimeEnvPlugin
        ):
            raise RuntimeError(
                f"Invalid runtime env plugin class {plugin_class}. "
                "The plugin class must inherit "
                "ray._private.runtime_env.plugin.RuntimeEnvPlugin."
            )
        if not plugin_class.name:
            raise RuntimeError(f"No valid name in runtime env plugin {plugin_class}.")
        if plugin_class.name in self.plugins:
            # Report the class of the already-registered plugin rather than
            # the setup-context wrapper, whose default repr is only an address.
            existing_class = type(self.plugins[plugin_class.name].class_instance)
            raise RuntimeError(
                f"The name of runtime env plugin {plugin_class} conflicts "
                f"with {existing_class}.",
            )

    def validate_priority(self, priority: Any) -> None:
        """Check that ``priority`` is an integer within the allowed range.

        Args:
            priority: The value to check.

        Raises:
            RuntimeError: If ``priority`` is not an integer (booleans are
                rejected too) or lies outside
                [RAY_RUNTIME_ENV_PLUGIN_MIN_PRIORITY,
                RAY_RUNTIME_ENV_PLUGIN_MAX_PRIORITY].
        """
        # bool is a subclass of int, so a JSON `true`/`false` would otherwise
        # be silently accepted as priority 1/0.
        if (
            isinstance(priority, bool)
            or not isinstance(priority, int)
            or priority < RAY_RUNTIME_ENV_PLUGIN_MIN_PRIORITY
            or priority > RAY_RUNTIME_ENV_PLUGIN_MAX_PRIORITY
        ):
            raise RuntimeError(
                f"Invalid runtime env priority {priority}, "
                "it should be an integer between "
                f"{RAY_RUNTIME_ENV_PLUGIN_MIN_PRIORITY} "
                f"and {RAY_RUNTIME_ENV_PLUGIN_MAX_PRIORITY}."
            )

    def load_plugins(self, plugin_configs: List[Dict]) -> None:
        """Load runtime env plugins and create URI caches for them.

        Args:
            plugin_configs: Plugin config dicts. Each must contain the
                ``RAY_RUNTIME_ENV_CLASS_FIELD_NAME`` field naming the plugin
                class to import, and may contain the
                ``RAY_RUNTIME_ENV_PRIORITY_FIELD_NAME`` field to override the
                class's own priority.

        Raises:
            RuntimeError: If a config entry is malformed, or the plugin class
                or priority fails validation.
        """
        for plugin_config in plugin_configs:
            if (
                not isinstance(plugin_config, dict)
                or RAY_RUNTIME_ENV_CLASS_FIELD_NAME not in plugin_config
            ):
                raise RuntimeError(
                    f"Invalid runtime env plugin config {plugin_config}, "
                    "it should be a object which contains the "
                    f"{RAY_RUNTIME_ENV_CLASS_FIELD_NAME} field."
                )
            plugin_class = import_attr(plugin_config[RAY_RUNTIME_ENV_CLASS_FIELD_NAME])
            self.validate_plugin_class(plugin_class)

            # The priority should be an integer between 0 and 100.
            # The default priority is 10. A smaller number indicates a
            # higher priority and the plugin will be set up first.
            if RAY_RUNTIME_ENV_PRIORITY_FIELD_NAME in plugin_config:
                priority = plugin_config[RAY_RUNTIME_ENV_PRIORITY_FIELD_NAME]
            else:
                priority = plugin_class.priority
            self.validate_priority(priority)

            class_instance = plugin_class()
            self.plugins[plugin_class.name] = PluginSetupContext(
                plugin_class.name,
                class_instance,
                priority,
                self.create_uri_cache_for_plugin(class_instance),
            )

    def add_plugin(self, plugin: RuntimeEnvPlugin) -> None:
        """Add a plugin to the manager and create a URI cache for it.

        Args:
            plugin: The class instance of the plugin.

        Raises:
            RuntimeError: If the plugin's class or its priority fails
                validation.
        """
        plugin_class = type(plugin)
        self.validate_plugin_class(plugin_class)
        self.validate_priority(plugin_class.priority)
        self.plugins[plugin_class.name] = PluginSetupContext(
            plugin_class.name,
            plugin,
            plugin_class.priority,
            self.create_uri_cache_for_plugin(plugin),
        )

    def create_uri_cache_for_plugin(self, plugin: RuntimeEnvPlugin) -> URICache:
        """Create a URI cache for a plugin.

        The cache size is read, in GB, from the environment variable
        ``RAY_RUNTIME_ENV_<PLUGIN NAME>_CACHE_SIZE_GB`` (upper-cased).

        Args:
            plugin: The plugin instance; its ``name`` selects the environment
                variable and its ``delete_uri`` is handed to the cache.

        Returns:
            The created URI cache for the plugin.
        """
        # Set the max size for the cache. Defaults to 10 GB.
        cache_size_env_var = f"RAY_RUNTIME_ENV_{plugin.name}_CACHE_SIZE_GB".upper()
        cache_size_bytes = int(
            (1024**3) * float(os.environ.get(cache_size_env_var, 10))
        )
        return URICache(plugin.delete_uri, cache_size_bytes)

    def sorted_plugin_setup_contexts(self) -> List[PluginSetupContext]:
        """Get the plugin setup contexts sorted by increasing priority value.

        Returns:
            The sorted plugin setup contexts.
        """
        return sorted(self.plugins.values(), key=lambda x: x.priority)
async def create_for_plugin_if_needed(
    runtime_env: "RuntimeEnv",  # noqa: F821
    plugin: RuntimeEnvPlugin,
    uri_cache: URICache,
    context: RuntimeEnvContext,
    logger: logging.Logger = default_logger,
):
    """Run a plugin's setup for ``runtime_env``, reusing cached URIs.

    Nothing happens when the runtime env has no entry (or a ``None`` entry)
    under the plugin's name. Otherwise the entry is validated, each URI that
    is missing from ``uri_cache`` is created and added to the cache, each
    URI already cached is marked as used, and finally the plugin modifies the
    context. A plugin that reports no URIs is always created, with a ``None``
    URI and no cache lookup.

    Args:
        runtime_env: The RuntimeEnv object being set up.
        plugin: The plugin performing the setup.
        uri_cache: The URI cache that belongs to ``plugin``.
        context: The context passed on to the plugin hooks.
        logger: Logger for setup messages; defaults to the module logger.
    """
    is_configured = (
        plugin.name in runtime_env and runtime_env[plugin.name] is not None
    )
    if not is_configured:
        return

    plugin.validate(runtime_env)
    plugin_uris = plugin.get_uris(runtime_env)

    if not plugin_uris:
        # Without URIs there is nothing to key the cache on.
        logger.debug(
            f"No URIs for runtime env plugin {plugin.name}; "
            "create always without checking the cache."
        )
        await plugin.create(None, runtime_env, context, logger=logger)

    for current_uri in plugin_uris:
        if current_uri in uri_cache:
            logger.info(
                f"Runtime env {plugin.name} {current_uri} is already installed "
                "and will be reused. Search "
                "all runtime_env_setup-*.log to find the corresponding setup log."
            )
            uri_cache.mark_used(current_uri, logger=logger)
            continue

        logger.debug(f"Cache miss for URI {current_uri}.")
        created_size = await plugin.create(
            current_uri, runtime_env, context, logger=logger
        )
        uri_cache.add(current_uri, created_size, logger=logger)

    plugin.modify_context(plugin_uris, runtime_env, context, logger)
|