plugin.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265
  1. import json
  2. import logging
  3. import os
  4. from abc import ABC
  5. from typing import Any, Dict, List, Optional, Type
  6. from ray._common.utils import import_attr
  7. from ray._private.runtime_env.constants import (
  8. RAY_RUNTIME_ENV_CLASS_FIELD_NAME,
  9. RAY_RUNTIME_ENV_PLUGIN_DEFAULT_PRIORITY,
  10. RAY_RUNTIME_ENV_PLUGIN_MAX_PRIORITY,
  11. RAY_RUNTIME_ENV_PLUGIN_MIN_PRIORITY,
  12. RAY_RUNTIME_ENV_PLUGINS_ENV_VAR,
  13. RAY_RUNTIME_ENV_PRIORITY_FIELD_NAME,
  14. )
  15. from ray._private.runtime_env.context import RuntimeEnvContext
  16. from ray._private.runtime_env.uri_cache import URICache
  17. from ray.util.annotations import DeveloperAPI
  18. default_logger = logging.getLogger(__name__)
  19. @DeveloperAPI
  20. class RuntimeEnvPlugin(ABC):
  21. """Abstract base class for runtime environment plugins."""
  22. name: str = None
  23. priority: int = RAY_RUNTIME_ENV_PLUGIN_DEFAULT_PRIORITY
  24. @staticmethod
  25. def validate(runtime_env_dict: dict) -> None:
  26. """Validate user entry for this plugin.
  27. The method is invoked upon installation of runtime env.
  28. Args:
  29. runtime_env_dict: The user-supplied runtime environment dict.
  30. Raises:
  31. ValueError: If the validation fails.
  32. """
  33. pass
  34. def get_uris(self, runtime_env: "RuntimeEnv") -> List[str]: # noqa: F821
  35. return []
  36. async def create(
  37. self,
  38. uri: Optional[str],
  39. runtime_env,
  40. context: RuntimeEnvContext,
  41. logger: logging.Logger,
  42. ) -> float:
  43. """Create and install the runtime environment.
  44. Gets called in the runtime env agent at install time. The URI can be
  45. used as a caching mechanism.
  46. Args:
  47. uri: A URI uniquely describing this resource.
  48. runtime_env: The RuntimeEnv object.
  49. context: Auxiliary information supplied by Ray.
  50. logger: A logger to log messages during the context modification.
  51. Returns:
  52. float: The disk space taken up by this plugin installation for this
  53. environment. e.g. for working_dir, this downloads the files to the
  54. local node.
  55. """
  56. return 0
  57. def modify_context(
  58. self,
  59. uris: List[str],
  60. runtime_env: "RuntimeEnv", # noqa: F821
  61. context: RuntimeEnvContext,
  62. logger: logging.Logger,
  63. ) -> None:
  64. """Modify context to change worker startup behavior.
  65. For example, you can use this to prepend "cd <dir>" command to worker
  66. startup, or add new environment variables.
  67. Args:
  68. uris: The URIs used by this resource.
  69. runtime_env: The RuntimeEnv object.
  70. context: Auxiliary information supplied by Ray.
  71. logger: A logger to log messages during the context modification.
  72. """
  73. return
  74. def delete_uri(self, uri: str, logger: logging.Logger) -> float:
  75. """Delete the runtime environment given uri.
  76. Args:
  77. uri: A URI uniquely describing this resource.
  78. logger: The logger used to log messages during the deletion.
  79. Returns:
  80. float: The amount of space reclaimed by the deletion.
  81. """
  82. return 0
  83. class PluginSetupContext:
  84. def __init__(
  85. self,
  86. name: str,
  87. class_instance: RuntimeEnvPlugin,
  88. priority: int,
  89. uri_cache: URICache,
  90. ):
  91. self.name = name
  92. self.class_instance = class_instance
  93. self.priority = priority
  94. self.uri_cache = uri_cache
  95. class RuntimeEnvPluginManager:
  96. """This manager is used to load plugins in runtime env agent."""
  97. def __init__(self):
  98. self.plugins: Dict[str, PluginSetupContext] = {}
  99. plugin_config_str = os.environ.get(RAY_RUNTIME_ENV_PLUGINS_ENV_VAR)
  100. if plugin_config_str:
  101. plugin_configs = json.loads(plugin_config_str)
  102. self.load_plugins(plugin_configs)
  103. def validate_plugin_class(self, plugin_class: Type[RuntimeEnvPlugin]) -> None:
  104. if not issubclass(plugin_class, RuntimeEnvPlugin):
  105. raise RuntimeError(
  106. f"Invalid runtime env plugin class {plugin_class}. "
  107. "The plugin class must inherit "
  108. "ray._private.runtime_env.plugin.RuntimeEnvPlugin."
  109. )
  110. if not plugin_class.name:
  111. raise RuntimeError(f"No valid name in runtime env plugin {plugin_class}.")
  112. if plugin_class.name in self.plugins:
  113. raise RuntimeError(
  114. f"The name of runtime env plugin {plugin_class} conflicts "
  115. f"with {self.plugins[plugin_class.name]}.",
  116. )
  117. def validate_priority(self, priority: Any) -> None:
  118. if (
  119. not isinstance(priority, int)
  120. or priority < RAY_RUNTIME_ENV_PLUGIN_MIN_PRIORITY
  121. or priority > RAY_RUNTIME_ENV_PLUGIN_MAX_PRIORITY
  122. ):
  123. raise RuntimeError(
  124. f"Invalid runtime env priority {priority}, "
  125. "it should be an integer between "
  126. f"{RAY_RUNTIME_ENV_PLUGIN_MIN_PRIORITY} "
  127. f"and {RAY_RUNTIME_ENV_PLUGIN_MAX_PRIORITY}."
  128. )
  129. def load_plugins(self, plugin_configs: List[Dict]) -> None:
  130. """Load runtime env plugins and create URI caches for them."""
  131. for plugin_config in plugin_configs:
  132. if (
  133. not isinstance(plugin_config, dict)
  134. or RAY_RUNTIME_ENV_CLASS_FIELD_NAME not in plugin_config
  135. ):
  136. raise RuntimeError(
  137. f"Invalid runtime env plugin config {plugin_config}, "
  138. "it should be a object which contains the "
  139. f"{RAY_RUNTIME_ENV_CLASS_FIELD_NAME} field."
  140. )
  141. plugin_class = import_attr(plugin_config[RAY_RUNTIME_ENV_CLASS_FIELD_NAME])
  142. self.validate_plugin_class(plugin_class)
  143. # The priority should be an integer between 0 and 100.
  144. # The default priority is 10. A smaller number indicates a
  145. # higher priority and the plugin will be set up first.
  146. if RAY_RUNTIME_ENV_PRIORITY_FIELD_NAME in plugin_config:
  147. priority = plugin_config[RAY_RUNTIME_ENV_PRIORITY_FIELD_NAME]
  148. else:
  149. priority = plugin_class.priority
  150. self.validate_priority(priority)
  151. class_instance = plugin_class()
  152. self.plugins[plugin_class.name] = PluginSetupContext(
  153. plugin_class.name,
  154. class_instance,
  155. priority,
  156. self.create_uri_cache_for_plugin(class_instance),
  157. )
  158. def add_plugin(self, plugin: RuntimeEnvPlugin) -> None:
  159. """Add a plugin to the manager and create a URI cache for it.
  160. Args:
  161. plugin: The class instance of the plugin.
  162. """
  163. plugin_class = type(plugin)
  164. self.validate_plugin_class(plugin_class)
  165. self.validate_priority(plugin_class.priority)
  166. self.plugins[plugin_class.name] = PluginSetupContext(
  167. plugin_class.name,
  168. plugin,
  169. plugin_class.priority,
  170. self.create_uri_cache_for_plugin(plugin),
  171. )
  172. def create_uri_cache_for_plugin(self, plugin: RuntimeEnvPlugin) -> URICache:
  173. """Create a URI cache for a plugin.
  174. Args:
  175. plugin_name: The name of the plugin.
  176. Returns:
  177. The created URI cache for the plugin.
  178. """
  179. # Set the max size for the cache. Defaults to 10 GB.
  180. cache_size_env_var = f"RAY_RUNTIME_ENV_{plugin.name}_CACHE_SIZE_GB".upper()
  181. cache_size_bytes = int(
  182. (1024**3) * float(os.environ.get(cache_size_env_var, 10))
  183. )
  184. return URICache(plugin.delete_uri, cache_size_bytes)
  185. def sorted_plugin_setup_contexts(self) -> List[PluginSetupContext]:
  186. """Get the sorted plugin setup contexts, sorted by increasing priority.
  187. Returns:
  188. The sorted plugin setup contexts.
  189. """
  190. return sorted(self.plugins.values(), key=lambda x: x.priority)
  191. async def create_for_plugin_if_needed(
  192. runtime_env: "RuntimeEnv", # noqa: F821
  193. plugin: RuntimeEnvPlugin,
  194. uri_cache: URICache,
  195. context: RuntimeEnvContext,
  196. logger: logging.Logger = default_logger,
  197. ):
  198. """Set up the environment using the plugin if not already set up and cached."""
  199. if plugin.name not in runtime_env or runtime_env[plugin.name] is None:
  200. return
  201. plugin.validate(runtime_env)
  202. uris = plugin.get_uris(runtime_env)
  203. if not uris:
  204. logger.debug(
  205. f"No URIs for runtime env plugin {plugin.name}; "
  206. "create always without checking the cache."
  207. )
  208. await plugin.create(None, runtime_env, context, logger=logger)
  209. for uri in uris:
  210. if uri not in uri_cache:
  211. logger.debug(f"Cache miss for URI {uri}.")
  212. size_bytes = await plugin.create(uri, runtime_env, context, logger=logger)
  213. uri_cache.add(uri, size_bytes, logger=logger)
  214. else:
  215. logger.info(
  216. f"Runtime env {plugin.name} {uri} is already installed "
  217. "and will be reused. Search "
  218. "all runtime_env_setup-*.log to find the corresponding setup log."
  219. )
  220. uri_cache.mark_used(uri, logger=logger)
  221. plugin.modify_context(uris, runtime_env, context, logger)