import atexit
import logging
from functools import partial
from types import FunctionType
from typing import Callable, Optional, Type, Union

import ray
import ray.cloudpickle as pickle
from ray.experimental.internal_kv import (
    _internal_kv_del,
    _internal_kv_get,
    _internal_kv_initialized,
    _internal_kv_put,
)
from ray.tune.error import TuneError
from ray.util.annotations import DeveloperAPI

# Registry categories. ``_Registry.register`` rejects any category that is not
# listed in KNOWN_CATEGORIES.
TRAINABLE_CLASS = "trainable_class"
ENV_CREATOR = "env_creator"
RLLIB_MODEL = "rllib_model"
RLLIB_PREPROCESSOR = "rllib_preprocessor"
RLLIB_ACTION_DIST = "rllib_action_dist"
RLLIB_INPUT = "rllib_input"
RLLIB_CONNECTOR = "rllib_connector"
TEST = "__test__"
KNOWN_CATEGORIES = [
    TRAINABLE_CLASS,
    ENV_CREATOR,
    RLLIB_MODEL,
    RLLIB_PREPROCESSOR,
    RLLIB_ACTION_DIST,
    RLLIB_INPUT,
    RLLIB_CONNECTOR,
    TEST,
]

logger = logging.getLogger(__name__)


def _has_trainable(trainable_name):
    """Return True if a trainable named ``trainable_name`` is registered."""
    return _global_registry.contains(TRAINABLE_CLASS, trainable_name)


@DeveloperAPI
def get_trainable_cls(trainable_name):
    """Return the trainable class registered under ``trainable_name``.

    Raises:
        TuneError: If no trainable with this name is registered, even after
            attempting to register RLlib's trainables.
    """
    validate_trainable(trainable_name)
    return _global_registry.get(TRAINABLE_CLASS, trainable_name)


@DeveloperAPI
def validate_trainable(trainable_name: str):
    """Raise ``TuneError`` unless ``trainable_name`` is a known trainable.

    If the name is not registered yet, RLlib's trainables are registered
    (when RLlib is importable) and the lookup is retried.
    """
    if not _has_trainable(trainable_name) and not _has_rllib_trainable(
        trainable_name
    ):
        raise TuneError(f"Unknown trainable: {trainable_name}")


def _has_rllib_trainable(trainable_name: str) -> bool:
    """Register all RLlib components, then check for ``trainable_name``.

    Returns False without registering anything if RLlib cannot be imported.
    """
    try:
        # Make sure everything rllib-related is registered.
        from ray.rllib import _register_all
    except ImportError:  # also covers ModuleNotFoundError (a subclass)
        return False
    _register_all()
    return _has_trainable(trainable_name)


@DeveloperAPI
def is_function_trainable(trainable: Union[str, Callable, Type]) -> bool:
    """Check if a given trainable is a function trainable.

    Either the trainable has been wrapped as a FunctionTrainable class already,
    or it's still a FunctionType/partial/callable.

    Args:
        trainable: A registered trainable name, a class, or a callable. Names
            are first resolved through the global registry.

    Returns:
        True for ``FunctionTrainable`` subclasses and for any non-class
        callable; False for every other class and for non-callables.
    """
    from ray.tune.trainable import FunctionTrainable

    if isinstance(trainable, str):
        trainable = get_trainable_cls(trainable)

    if isinstance(trainable, type):
        return issubclass(trainable, FunctionTrainable)
    # Plain functions and ``functools.partial`` objects are both callable, so
    # a single ``callable`` check covers every non-class case.
    return callable(trainable)


@DeveloperAPI
def register_trainable(name: str, trainable: Union[Callable, Type], warn: bool = True):
    """Register a trainable function or class.

    This enables a class or function to be accessed on every Ray process
    in the cluster.

    Args:
        name: Name to register.
        trainable: Function or tune.Trainable class. Functions, partials and
            other non-class callables are converted into a class with
            ``wrap_function`` during registration.
        warn: Currently unused; kept for backward compatibility.

    Raises:
        TypeError: If ``trainable`` is neither a ``Trainable`` subclass nor
            something ``wrap_function`` turns into one.
    """
    from ray.tune.trainable import Trainable, wrap_function

    if isinstance(trainable, type):
        logger.debug("Detected class for trainable.")
    elif isinstance(trainable, (FunctionType, partial)):
        logger.debug("Detected function for trainable.")
        trainable = wrap_function(trainable)
    elif callable(trainable):
        logger.info("Detected unknown callable for trainable. Converting to class.")
        trainable = wrap_function(trainable)

    # Guard with isinstance first: issubclass() itself raises an unrelated
    # TypeError when handed a non-class (e.g. a non-callable object).
    if not (isinstance(trainable, type) and issubclass(trainable, Trainable)):
        raise TypeError("Second argument must be convertable to Trainable", trainable)

    _global_registry.register(TRAINABLE_CLASS, name, trainable)


def _unregister_trainables():
    """Remove every registered trainable from the global registry."""
    _global_registry.unregister_all(TRAINABLE_CLASS)


@DeveloperAPI
def register_env(name: str, env_creator: Callable):
    """Register a custom environment for use with RLlib.

    This enables the environment to be accessed on every Ray process
    in the cluster.

    Args:
        name: Name to register.
        env_creator: Callable that creates an env.

    Raises:
        TypeError: If ``env_creator`` is not callable.
    """
    if not callable(env_creator):
        raise TypeError("Second argument must be callable.", env_creator)
    _global_registry.register(ENV_CREATOR, name, env_creator)


def _unregister_envs():
    """Remove every registered env creator from the global registry."""
    _global_registry.unregister_all(ENV_CREATOR)


@DeveloperAPI
def register_input(name: str, input_creator: Callable):
    """Register a custom input api for RLlib.

    Args:
        name: Name to register.
        input_creator: Callable that creates an input reader.

    Raises:
        TypeError: If ``input_creator`` is not callable.
    """
    if not callable(input_creator):
        raise TypeError("Second argument must be callable.", input_creator)
    _global_registry.register(RLLIB_INPUT, name, input_creator)


def _unregister_inputs():
    """Remove every registered input creator from the global registry."""
    _global_registry.unregister_all(RLLIB_INPUT)


@DeveloperAPI
def registry_contains_input(name: str) -> bool:
    """Return True if an RLlib input creator is registered under ``name``."""
    return _global_registry.contains(RLLIB_INPUT, name)


@DeveloperAPI
def registry_get_input(name: str) -> Callable:
    """Return the RLlib input creator registered under ``name``."""
    return _global_registry.get(RLLIB_INPUT, name)


def _unregister_all():
    """Unregister inputs, envs and trainables.

    Installed as an ``atexit`` handler on the driver by
    ``_Registry._register_atexit``.
    """
    _unregister_inputs()
    _unregister_envs()
    _unregister_trainables()


def _check_serializability(key, value):
    """Register ``value`` under the TEST category.

    Registration pickles the value with ``pickle.dumps_debug``, so
    serialization problems surface here as exceptions.
    """
    _global_registry.register(TEST, key, value)


def _make_key(prefix: str, category: str, key: str):
    """Generate a binary key for the given category and key.

    Args:
        prefix: Namespace prefix, e.g. the registry's job ID.
        category: The category of the item.
        key: The unique identifier for the item.

    Returns:
        The key to use for storing the value, in the form
        ``b"TuneRegistry:<prefix>:<category>/<key>"``.

    Raises:
        UnicodeEncodeError: If any component is not ASCII.
    """
    return (
        b"TuneRegistry:"
        + prefix.encode("ascii")
        + b":"
        + category.encode("ascii")
        + b"/"
        + key.encode("ascii")
    )


class _Registry:
    """Stores pickled values, in Ray's internal KV store once it is available.

    Values registered before the KV store is initialized are buffered in
    ``_to_flush`` and written out by ``flush_values``.
    """

    def __init__(self, prefix: Optional[str] = None):
        """If no prefix is given, use runtime context job ID."""
        # (category, key) -> pickled bytes not yet written to the KV store.
        self._to_flush = {}
        self._prefix = prefix
        # (category, key) pairs that have been written to the KV store.
        self._registered = set()
        self._atexit_handler_registered = False

    @property
    def prefix(self):
        """Key prefix; lazily defaults to the current job ID."""
        if not self._prefix:
            self._prefix = ray.get_runtime_context().get_job_id()
        return self._prefix

    def _register_atexit(self):
        """Install ``_unregister_all`` as an atexit handler, once, on the driver."""
        if self._atexit_handler_registered:
            # Already registered
            return

        if ray._private.worker.global_worker.mode != ray.SCRIPT_MODE:
            # Only cleanup on the driver
            return

        atexit.register(_unregister_all)
        self._atexit_handler_registered = True

    def register(self, category, key, value):
        """Registers the value with the global registry.

        The value is pickled immediately and written to the KV store right
        away if it is initialized; otherwise it is buffered until
        ``flush_values`` runs.

        Args:
            category: The category to register under.
            key: The key to register under.
            value: The value to register.

        Raises:
            TuneError: If ``category`` is not in ``KNOWN_CATEGORIES``.
            Serialization errors from ``pickle.dumps_debug`` propagate if
            ``value`` cannot be pickled.
        """
        if category not in KNOWN_CATEGORIES:
            raise TuneError(
                "Unknown category {} not among {}".format(category, KNOWN_CATEGORIES)
            )
        self._to_flush[(category, key)] = pickle.dumps_debug(value)
        if _internal_kv_initialized():
            self.flush_values()

    def unregister(self, category, key):
        """Remove one value from the KV store, or from the pending buffer."""
        if _internal_kv_initialized():
            _internal_kv_del(_make_key(self.prefix, category, key))
        else:
            self._to_flush.pop((category, key), None)

    def unregister_all(self, category: Optional[str] = None):
        """Unregister every value in ``category`` (all categories if None).

        Both flushed values and values still waiting to be flushed are
        removed, so a later ``flush_values`` will not resurrect them.
        """

        def _matches(cat):
            return category is None or cat == category

        # Drop values that were registered but never flushed.
        for pending in [k for k in self._to_flush if _matches(k[0])]:
            del self._to_flush[pending]

        remaining = set()
        for cat, key in self._registered:
            if _matches(cat):
                self.unregister(cat, key)
            else:
                remaining.add((cat, key))
        self._registered = remaining

    def contains(self, category, key):
        """Return True if a value is registered (or pending) for category/key."""
        if _internal_kv_initialized():
            value = _internal_kv_get(_make_key(self.prefix, category, key))
            return value is not None
        else:
            return (category, key) in self._to_flush

    def get(self, category, key):
        """Return the unpickled value registered for category/key.

        Raises:
            ValueError: If the KV store is initialized and holds no value.
            KeyError: If the KV store is not initialized and nothing is
                pending for this category/key.
        """
        if _internal_kv_initialized():
            value = _internal_kv_get(_make_key(self.prefix, category, key))
            if value is None:
                raise ValueError(
                    "Registry value for {}/{} doesn't exist.".format(category, key)
                )
            return pickle.loads(value)
        else:
            return pickle.loads(self._to_flush[(category, key)])

    def flush_values(self):
        """Write all pending values to the KV store and clear the buffer."""
        self._register_atexit()
        for (category, key), value in self._to_flush.items():
            _internal_kv_put(
                _make_key(self.prefix, category, key), value, overwrite=True
            )
            self._registered.add((category, key))
        self._to_flush.clear()


_global_registry = _Registry()
# Flush anything registered before ``ray.init`` as soon as Ray initializes.
ray._private.worker._post_init_hooks.append(_global_registry.flush_values)


class _ParameterRegistry:
    """Key/value store whose values live in Ray's object store once Ray is up.

    Values put before Ray is initialized are kept locally in ``to_flush``;
    ``flush`` moves them into the object store (``ray.put``), storing values
    that are already ``ObjectRef``s as-is.
    """

    def __init__(self):
        # key -> raw value (or ObjectRef) not yet moved to the object store.
        self.to_flush = {}
        # key -> ObjectRef in the object store.
        self.references = {}

    def put(self, k, v):
        """Store ``v`` under ``k``; flushed immediately if Ray is initialized."""
        self.to_flush[k] = v
        if ray.is_initialized():
            self.flush()

    def get(self, k):
        """Return the value stored under ``k``.

        Before Ray is initialized this reads the local buffer. Afterwards,
        any still-pending values are flushed first, so values put before
        ``ray.init`` remain retrievable and the newest put wins.
        """
        if not ray.is_initialized():
            return self.to_flush[k]
        if k in self.to_flush:
            self.flush()
        return ray.get(self.references[k])

    def flush(self):
        """Move every pending value into the object store."""
        for k, v in self.to_flush.items():
            if isinstance(v, ray.ObjectRef):
                self.references[k] = v
            else:
                self.references[k] = ray.put(v)
        self.to_flush.clear()