registry.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319
  1. import atexit
  2. import logging
  3. from functools import partial
  4. from types import FunctionType
  5. from typing import Callable, Optional, Type, Union
  6. import ray
  7. import ray.cloudpickle as pickle
  8. from ray.experimental.internal_kv import (
  9. _internal_kv_del,
  10. _internal_kv_get,
  11. _internal_kv_initialized,
  12. _internal_kv_put,
  13. )
  14. from ray.tune.error import TuneError
  15. from ray.util.annotations import DeveloperAPI
  16. TRAINABLE_CLASS = "trainable_class"
  17. ENV_CREATOR = "env_creator"
  18. RLLIB_MODEL = "rllib_model"
  19. RLLIB_PREPROCESSOR = "rllib_preprocessor"
  20. RLLIB_ACTION_DIST = "rllib_action_dist"
  21. RLLIB_INPUT = "rllib_input"
  22. RLLIB_CONNECTOR = "rllib_connector"
  23. TEST = "__test__"
  24. KNOWN_CATEGORIES = [
  25. TRAINABLE_CLASS,
  26. ENV_CREATOR,
  27. RLLIB_MODEL,
  28. RLLIB_PREPROCESSOR,
  29. RLLIB_ACTION_DIST,
  30. RLLIB_INPUT,
  31. RLLIB_CONNECTOR,
  32. TEST,
  33. ]
  34. logger = logging.getLogger(__name__)
  35. def _has_trainable(trainable_name):
  36. return _global_registry.contains(TRAINABLE_CLASS, trainable_name)
  37. @DeveloperAPI
  38. def get_trainable_cls(trainable_name):
  39. validate_trainable(trainable_name)
  40. return _global_registry.get(TRAINABLE_CLASS, trainable_name)
  41. @DeveloperAPI
  42. def validate_trainable(trainable_name: str):
  43. if not _has_trainable(trainable_name) and not _has_rllib_trainable(trainable_name):
  44. raise TuneError(f"Unknown trainable: {trainable_name}")
  45. def _has_rllib_trainable(trainable_name: str) -> bool:
  46. try:
  47. # Make sure everything rllib-related is registered.
  48. from ray.rllib import _register_all
  49. except (ImportError, ModuleNotFoundError):
  50. return False
  51. _register_all()
  52. return _has_trainable(trainable_name)
  53. @DeveloperAPI
  54. def is_function_trainable(trainable: Union[str, Callable, Type]) -> bool:
  55. """Check if a given trainable is a function trainable.
  56. Either the trainable has been wrapped as a FunctionTrainable class already,
  57. or it's still a FunctionType/partial/callable."""
  58. from ray.tune.trainable import FunctionTrainable
  59. if isinstance(trainable, str):
  60. trainable = get_trainable_cls(trainable)
  61. is_wrapped_func = isinstance(trainable, type) and issubclass(
  62. trainable, FunctionTrainable
  63. )
  64. return is_wrapped_func or (
  65. not isinstance(trainable, type)
  66. and (
  67. isinstance(trainable, FunctionType)
  68. or isinstance(trainable, partial)
  69. or callable(trainable)
  70. )
  71. )
  72. @DeveloperAPI
  73. def register_trainable(name: str, trainable: Union[Callable, Type], warn: bool = True):
  74. """Register a trainable function or class.
  75. This enables a class or function to be accessed on every Ray process
  76. in the cluster.
  77. Args:
  78. name: Name to register.
  79. trainable: Function or tune.Trainable class. Functions must
  80. take (config, status_reporter) as arguments and will be
  81. automatically converted into a class during registration.
  82. """
  83. from ray.tune.trainable import Trainable, wrap_function
  84. if isinstance(trainable, type):
  85. logger.debug("Detected class for trainable.")
  86. elif isinstance(trainable, FunctionType) or isinstance(trainable, partial):
  87. logger.debug("Detected function for trainable.")
  88. trainable = wrap_function(trainable)
  89. elif callable(trainable):
  90. logger.info("Detected unknown callable for trainable. Converting to class.")
  91. trainable = wrap_function(trainable)
  92. if not issubclass(trainable, Trainable):
  93. raise TypeError("Second argument must be convertable to Trainable", trainable)
  94. _global_registry.register(TRAINABLE_CLASS, name, trainable)
  95. def _unregister_trainables():
  96. _global_registry.unregister_all(TRAINABLE_CLASS)
  97. @DeveloperAPI
  98. def register_env(name: str, env_creator: Callable):
  99. """Register a custom environment for use with RLlib.
  100. This enables the environment to be accessed on every Ray process
  101. in the cluster.
  102. Args:
  103. name: Name to register.
  104. env_creator: Callable that creates an env.
  105. """
  106. if not callable(env_creator):
  107. raise TypeError("Second argument must be callable.", env_creator)
  108. _global_registry.register(ENV_CREATOR, name, env_creator)
  109. def _unregister_envs():
  110. _global_registry.unregister_all(ENV_CREATOR)
  111. @DeveloperAPI
  112. def register_input(name: str, input_creator: Callable):
  113. """Register a custom input api for RLlib.
  114. Args:
  115. name: Name to register.
  116. input_creator: Callable that creates an
  117. input reader.
  118. """
  119. if not callable(input_creator):
  120. raise TypeError("Second argument must be callable.", input_creator)
  121. _global_registry.register(RLLIB_INPUT, name, input_creator)
  122. def _unregister_inputs():
  123. _global_registry.unregister_all(RLLIB_INPUT)
  124. @DeveloperAPI
  125. def registry_contains_input(name: str) -> bool:
  126. return _global_registry.contains(RLLIB_INPUT, name)
  127. @DeveloperAPI
  128. def registry_get_input(name: str) -> Callable:
  129. return _global_registry.get(RLLIB_INPUT, name)
  130. def _unregister_all():
  131. _unregister_inputs()
  132. _unregister_envs()
  133. _unregister_trainables()
  134. def _check_serializability(key, value):
  135. _global_registry.register(TEST, key, value)
  136. def _make_key(prefix: str, category: str, key: str):
  137. """Generate a binary key for the given category and key.
  138. Args:
  139. prefix: Prefix
  140. category: The category of the item
  141. key: The unique identifier for the item
  142. Returns:
  143. The key to use for storing a the value.
  144. """
  145. return (
  146. b"TuneRegistry:"
  147. + prefix.encode("ascii")
  148. + b":"
  149. + category.encode("ascii")
  150. + b"/"
  151. + key.encode("ascii")
  152. )
  153. class _Registry:
  154. def __init__(self, prefix: Optional[str] = None):
  155. """If no prefix is given, use runtime context job ID."""
  156. self._to_flush = {}
  157. self._prefix = prefix
  158. self._registered = set()
  159. self._atexit_handler_registered = False
  160. @property
  161. def prefix(self):
  162. if not self._prefix:
  163. self._prefix = ray.get_runtime_context().get_job_id()
  164. return self._prefix
  165. def _register_atexit(self):
  166. if self._atexit_handler_registered:
  167. # Already registered
  168. return
  169. if ray._private.worker.global_worker.mode != ray.SCRIPT_MODE:
  170. # Only cleanup on the driver
  171. return
  172. atexit.register(_unregister_all)
  173. self._atexit_handler_registered = True
  174. def register(self, category, key, value):
  175. """Registers the value with the global registry.
  176. Args:
  177. category: The category to register under.
  178. key: The key to register under.
  179. value: The value to register.
  180. Raises:
  181. PicklingError: If unable to pickle to provided file.
  182. """
  183. if category not in KNOWN_CATEGORIES:
  184. from ray.tune import TuneError
  185. raise TuneError(
  186. "Unknown category {} not among {}".format(category, KNOWN_CATEGORIES)
  187. )
  188. self._to_flush[(category, key)] = pickle.dumps_debug(value)
  189. if _internal_kv_initialized():
  190. self.flush_values()
  191. def unregister(self, category, key):
  192. if _internal_kv_initialized():
  193. _internal_kv_del(_make_key(self.prefix, category, key))
  194. else:
  195. self._to_flush.pop((category, key), None)
  196. def unregister_all(self, category: Optional[str] = None):
  197. remaining = set()
  198. for cat, key in self._registered:
  199. if category and category == cat:
  200. self.unregister(cat, key)
  201. else:
  202. remaining.add((cat, key))
  203. self._registered = remaining
  204. def contains(self, category, key):
  205. if _internal_kv_initialized():
  206. value = _internal_kv_get(_make_key(self.prefix, category, key))
  207. return value is not None
  208. else:
  209. return (category, key) in self._to_flush
  210. def get(self, category, key):
  211. if _internal_kv_initialized():
  212. value = _internal_kv_get(_make_key(self.prefix, category, key))
  213. if value is None:
  214. raise ValueError(
  215. "Registry value for {}/{} doesn't exist.".format(category, key)
  216. )
  217. return pickle.loads(value)
  218. else:
  219. return pickle.loads(self._to_flush[(category, key)])
  220. def flush_values(self):
  221. self._register_atexit()
  222. for (category, key), value in self._to_flush.items():
  223. _internal_kv_put(
  224. _make_key(self.prefix, category, key), value, overwrite=True
  225. )
  226. self._registered.add((category, key))
  227. self._to_flush.clear()
  228. _global_registry = _Registry()
  229. ray._private.worker._post_init_hooks.append(_global_registry.flush_values)
  230. class _ParameterRegistry:
  231. def __init__(self):
  232. self.to_flush = {}
  233. self.references = {}
  234. def put(self, k, v):
  235. self.to_flush[k] = v
  236. if ray.is_initialized():
  237. self.flush()
  238. def get(self, k):
  239. if not ray.is_initialized():
  240. return self.to_flush[k]
  241. return ray.get(self.references[k])
  242. def flush(self):
  243. for k, v in self.to_flush.items():
  244. if isinstance(v, ray.ObjectRef):
  245. self.references[k] = v
  246. else:
  247. self.references[k] = ray.put(v)
  248. self.to_flush.clear()