| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319 |
- import atexit
- import logging
- from functools import partial
- from types import FunctionType
- from typing import Callable, Optional, Type, Union
- import ray
- import ray.cloudpickle as pickle
- from ray.experimental.internal_kv import (
- _internal_kv_del,
- _internal_kv_get,
- _internal_kv_initialized,
- _internal_kv_put,
- )
- from ray.tune.error import TuneError
- from ray.util.annotations import DeveloperAPI
# Category names under which values can be stored in the registry.
TRAINABLE_CLASS = "trainable_class"
ENV_CREATOR = "env_creator"
RLLIB_MODEL = "rllib_model"
RLLIB_PREPROCESSOR = "rllib_preprocessor"
RLLIB_ACTION_DIST = "rllib_action_dist"
RLLIB_INPUT = "rllib_input"
RLLIB_CONNECTOR = "rllib_connector"
# Category used by `_check_serializability` for its trial registrations.
TEST = "__test__"
# `_Registry.register` rejects any category that is not listed here.
KNOWN_CATEGORIES = [
    TRAINABLE_CLASS,
    ENV_CREATOR,
    RLLIB_MODEL,
    RLLIB_PREPROCESSOR,
    RLLIB_ACTION_DIST,
    RLLIB_INPUT,
    RLLIB_CONNECTOR,
    TEST,
]
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
def _has_trainable(trainable_name):
    """Return whether the global registry holds a trainable with this name."""
    found = _global_registry.contains(TRAINABLE_CLASS, trainable_name)
    return found
@DeveloperAPI
def get_trainable_cls(trainable_name):
    """Look up the trainable registered under ``trainable_name``.

    The name is validated first, so ``validate_trainable`` raises for an
    unknown name before the registry is queried.
    """
    validate_trainable(trainable_name)
    trainable_cls = _global_registry.get(TRAINABLE_CLASS, trainable_name)
    return trainable_cls
@DeveloperAPI
def validate_trainable(trainable_name: str):
    """Raise ``TuneError`` unless ``trainable_name`` is a known trainable.

    The registry is checked first; only if the name is missing there is the
    RLlib registration path tried.
    """
    if _has_trainable(trainable_name):
        return
    if _has_rllib_trainable(trainable_name):
        return
    raise TuneError(f"Unknown trainable: {trainable_name}")
def _has_rllib_trainable(trainable_name: str) -> bool:
    """Check for ``trainable_name`` after triggering RLlib's registrations.

    Returns False when ``ray.rllib`` cannot be imported.
    """
    try:
        # Make sure everything rllib-related is registered.
        from ray.rllib import _register_all as register_rllib_trainables
    except ImportError:
        # ModuleNotFoundError is a subclass of ImportError, so it lands here too.
        return False

    register_rllib_trainables()
    return _has_trainable(trainable_name)
@DeveloperAPI
def is_function_trainable(trainable: Union[str, Callable, Type]) -> bool:
    """Tell whether ``trainable`` is a function trainable.

    A string is first resolved through ``get_trainable_cls``. A class counts
    only if it subclasses ``FunctionTrainable`` (i.e. it was already wrapped);
    anything that is not a class counts if it is a function, a partial, or
    any other callable.
    """
    from ray.tune.trainable import FunctionTrainable

    if isinstance(trainable, str):
        trainable = get_trainable_cls(trainable)

    if isinstance(trainable, type):
        # Classes qualify only when they are wrapped function trainables.
        return issubclass(trainable, FunctionTrainable)

    # Not a class: still a raw function/partial/callable.
    return isinstance(trainable, (FunctionType, partial)) or callable(trainable)
@DeveloperAPI
def register_trainable(name: str, trainable: Union[Callable, Type], warn: bool = True):
    """Register a trainable function or class.

    This enables a class or function to be accessed on every Ray process
    in the cluster.

    Args:
        name: Name to register.
        trainable: Function or tune.Trainable class. Functions must
            take (config, status_reporter) as arguments and will be
            automatically converted into a class during registration.
        warn: Not used by this function; accepted so that callers which
            pass it keep working.

    Raises:
        TypeError: If ``trainable`` is neither a ``Trainable`` subclass nor
            something that ``wrap_function`` converts into one.
    """
    from ray.tune.trainable import Trainable, wrap_function

    if isinstance(trainable, type):
        logger.debug("Detected class for trainable.")
    elif isinstance(trainable, (FunctionType, partial)):
        logger.debug("Detected function for trainable.")
        trainable = wrap_function(trainable)
    elif callable(trainable):
        logger.info("Detected unknown callable for trainable. Converting to class.")
        trainable = wrap_function(trainable)

    # Values that are neither classes nor callables (e.g. None or a string)
    # reach this point unchanged. issubclass() would reject them with an
    # unrelated "arg 1 must be a class" message, so test for a class first and
    # raise the intended error instead.
    if not (isinstance(trainable, type) and issubclass(trainable, Trainable)):
        raise TypeError("Second argument must be convertable to Trainable", trainable)
    _global_registry.register(TRAINABLE_CLASS, name, trainable)
def _unregister_trainables():
    """Remove the tracked trainable-class entries from the global registry."""
    category = TRAINABLE_CLASS
    _global_registry.unregister_all(category)
@DeveloperAPI
def register_env(name: str, env_creator: Callable):
    """Register a custom environment for use with RLlib.

    This enables the environment to be accessed on every Ray process
    in the cluster.

    Args:
        name: Name to register.
        env_creator: Callable that creates an env.

    Raises:
        TypeError: If ``env_creator`` is not callable.
    """
    if callable(env_creator):
        _global_registry.register(ENV_CREATOR, name, env_creator)
        return
    raise TypeError("Second argument must be callable.", env_creator)
def _unregister_envs():
    """Remove the tracked env-creator entries from the global registry."""
    category = ENV_CREATOR
    _global_registry.unregister_all(category)
@DeveloperAPI
def register_input(name: str, input_creator: Callable):
    """Register a custom input api for RLlib.

    Args:
        name: Name to register.
        input_creator: Callable that creates an
            input reader.

    Raises:
        TypeError: If ``input_creator`` is not callable.
    """
    creator_is_callable = callable(input_creator)
    if not creator_is_callable:
        raise TypeError("Second argument must be callable.", input_creator)

    _global_registry.register(RLLIB_INPUT, name, input_creator)
def _unregister_inputs():
    """Remove the tracked RLlib input entries from the global registry."""
    category = RLLIB_INPUT
    _global_registry.unregister_all(category)
@DeveloperAPI
def registry_contains_input(name: str) -> bool:
    """Return whether an RLlib input is registered under ``name``."""
    is_registered = _global_registry.contains(RLLIB_INPUT, name)
    return is_registered
@DeveloperAPI
def registry_get_input(name: str) -> Callable:
    """Fetch the RLlib input creator registered under ``name``."""
    input_creator = _global_registry.get(RLLIB_INPUT, name)
    return input_creator
def _unregister_all():
    """Unregister inputs, envs, and trainables, in that order."""
    cleanup_steps = (_unregister_inputs, _unregister_envs, _unregister_trainables)
    for cleanup in cleanup_steps:
        cleanup()
def _check_serializability(key, value):
    """Register ``value`` under the test category as a serialization check.

    Registration goes through the global registry, so any error raised while
    registering the value propagates to the caller.
    """
    category = TEST
    _global_registry.register(category, key, value)
- def _make_key(prefix: str, category: str, key: str):
- """Generate a binary key for the given category and key.
- Args:
- prefix: Prefix
- category: The category of the item
- key: The unique identifier for the item
- Returns:
- The key to use for storing a the value.
- """
- return (
- b"TuneRegistry:"
- + prefix.encode("ascii")
- + b":"
- + category.encode("ascii")
- + b"/"
- + key.encode("ascii")
- )
- class _Registry:
- def __init__(self, prefix: Optional[str] = None):
- """If no prefix is given, use runtime context job ID."""
- self._to_flush = {}
- self._prefix = prefix
- self._registered = set()
- self._atexit_handler_registered = False
- @property
- def prefix(self):
- if not self._prefix:
- self._prefix = ray.get_runtime_context().get_job_id()
- return self._prefix
- def _register_atexit(self):
- if self._atexit_handler_registered:
- # Already registered
- return
- if ray._private.worker.global_worker.mode != ray.SCRIPT_MODE:
- # Only cleanup on the driver
- return
- atexit.register(_unregister_all)
- self._atexit_handler_registered = True
- def register(self, category, key, value):
- """Registers the value with the global registry.
- Args:
- category: The category to register under.
- key: The key to register under.
- value: The value to register.
- Raises:
- PicklingError: If unable to pickle to provided file.
- """
- if category not in KNOWN_CATEGORIES:
- from ray.tune import TuneError
- raise TuneError(
- "Unknown category {} not among {}".format(category, KNOWN_CATEGORIES)
- )
- self._to_flush[(category, key)] = pickle.dumps_debug(value)
- if _internal_kv_initialized():
- self.flush_values()
- def unregister(self, category, key):
- if _internal_kv_initialized():
- _internal_kv_del(_make_key(self.prefix, category, key))
- else:
- self._to_flush.pop((category, key), None)
- def unregister_all(self, category: Optional[str] = None):
- remaining = set()
- for cat, key in self._registered:
- if category and category == cat:
- self.unregister(cat, key)
- else:
- remaining.add((cat, key))
- self._registered = remaining
- def contains(self, category, key):
- if _internal_kv_initialized():
- value = _internal_kv_get(_make_key(self.prefix, category, key))
- return value is not None
- else:
- return (category, key) in self._to_flush
- def get(self, category, key):
- if _internal_kv_initialized():
- value = _internal_kv_get(_make_key(self.prefix, category, key))
- if value is None:
- raise ValueError(
- "Registry value for {}/{} doesn't exist.".format(category, key)
- )
- return pickle.loads(value)
- else:
- return pickle.loads(self._to_flush[(category, key)])
- def flush_values(self):
- self._register_atexit()
- for (category, key), value in self._to_flush.items():
- _internal_kv_put(
- _make_key(self.prefix, category, key), value, overwrite=True
- )
- self._registered.add((category, key))
- self._to_flush.clear()
# Module-wide registry instance used by the register/unregister helpers.
_global_registry = _Registry()
# Have Ray's post-init hooks flush any values that were registered earlier.
ray._private.worker._post_init_hooks.append(_global_registry.flush_values)
class _ParameterRegistry:
    """Store parameters, converting them to Ray object references on flush.

    Values put while Ray is not initialized stay in ``to_flush`` until
    ``flush`` is called.
    """

    def __init__(self):
        # Key -> raw value (or ObjectRef) waiting to be flushed.
        self.to_flush = {}
        # Key -> ObjectRef for values that have been flushed.
        self.references = {}

    def put(self, k, v):
        """Queue ``v`` under ``k`` and flush right away if Ray is initialized."""
        self.to_flush[k] = v
        if not ray.is_initialized():
            return
        self.flush()

    def get(self, k):
        """Return the value for ``k``.

        With Ray initialized the value is fetched through its object
        reference; otherwise the still-pending raw value is returned.
        """
        if ray.is_initialized():
            return ray.get(self.references[k])
        return self.to_flush[k]

    def flush(self):
        """Move every pending value into ``references`` and empty the queue."""
        for name, pending in self.to_flush.items():
            # Values that are already object references are stored as-is.
            already_ref = isinstance(pending, ray.ObjectRef)
            self.references[name] = pending if already_ref else ray.put(pending)
        self.to_flush.clear()
|