__init__.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208
  1. import contextlib
  2. from collections import deque
  3. from functools import partial
  4. from typing import Any, Dict, List, Optional, Tuple, Union
  5. import tree
  6. from ray._common.deprecation import deprecation_warning
  7. from ray.rllib.utils.annotations import DeveloperAPI, PublicAPI, override
  8. from ray.rllib.utils.filter import Filter
  9. from ray.rllib.utils.filter_manager import FilterManager
  10. from ray.rllib.utils.framework import (
  11. try_import_jax,
  12. try_import_tf,
  13. try_import_tfp,
  14. try_import_torch,
  15. )
  16. from ray.rllib.utils.numpy import (
  17. LARGE_INTEGER,
  18. MAX_LOG_NN_OUTPUT,
  19. MIN_LOG_NN_OUTPUT,
  20. SMALL_NUMBER,
  21. fc,
  22. lstm,
  23. one_hot,
  24. relu,
  25. sigmoid,
  26. softmax,
  27. )
  28. from ray.rllib.utils.schedules import (
  29. ConstantSchedule,
  30. ExponentialSchedule,
  31. LinearSchedule,
  32. PiecewiseSchedule,
  33. PolynomialSchedule,
  34. )
  35. from ray.rllib.utils.test_utils import (
  36. check,
  37. check_compute_single_action,
  38. check_train_results,
  39. )
  40. from ray.tune.utils import deep_update, merge_dicts
  41. @DeveloperAPI
  42. def add_mixins(base, mixins, reversed=False):
  43. """Returns a new class with mixins applied in priority order."""
  44. mixins = list(mixins or [])
  45. while mixins:
  46. if reversed:
  47. class new_base(base, mixins.pop()):
  48. pass
  49. else:
  50. class new_base(mixins.pop(), base):
  51. pass
  52. base = new_base
  53. return base
  54. @DeveloperAPI
  55. def force_list(
  56. elements: Optional[Any] = None, to_tuple: bool = False
  57. ) -> Union[List, Tuple]:
  58. """
  59. Makes sure `elements` is returned as a list, whether `elements` is a single
  60. item, already a list, or a tuple.
  61. Args:
  62. elements: The inputs as a single item, a list/tuple/deque of items, or None,
  63. to be converted to a list/tuple. If None, returns empty list/tuple.
  64. to_tuple: Whether to use tuple (instead of list).
  65. Returns:
  66. The provided item in a list of size 1, or the provided items as a
  67. list. If `elements` is None, returns an empty list. If `to_tuple` is True,
  68. returns a tuple instead of a list.
  69. """
  70. ctor = list
  71. if to_tuple is True:
  72. ctor = tuple
  73. return (
  74. ctor()
  75. if elements is None
  76. else ctor(elements)
  77. if type(elements) in [list, set, tuple, deque]
  78. else ctor([elements])
  79. )
  80. @DeveloperAPI
  81. def flatten_dict(nested: Dict[str, Any], sep="/", env_steps=0) -> Dict[str, Any]:
  82. """
  83. Flattens a nested dict into a flat dict with joined keys.
  84. Note, this is used for better serialization of nested dictionaries
  85. in `OfflinePreLearner.__call__` when called inside
  86. `ray.data.Dataset.map_batches`.
  87. Note, this is used to return a `Dict[str, numpy.ndarray] from the
  88. `__call__` method which is expected by Ray Data.
  89. Args:
  90. nested: A nested dictionary.
  91. sep: Separator to use when joining keys.
  92. Returns:
  93. A flat dictionary where each key is a path of keys in the nested dict.
  94. """
  95. flat = {}
  96. # `dm_tree.flatten_with_path`` returns a list of `(path, leaf)` tuples.
  97. for path, leaf in tree.flatten_with_path(nested):
  98. # Create a single string key from the path.
  99. key = sep.join(map(str, path))
  100. flat[key] = leaf
  101. return flat
  102. @DeveloperAPI
  103. def unflatten_dict(flat: Dict[str, Any], sep="/") -> Dict[str, Any]:
  104. """
  105. Reconstructs a nested dict from a flat dict with joined keys.
  106. Note, this is used for better deserialization ofr nested dictionaries
  107. in `Learner.update' calls in which a `ray.data.DataIterator` is used.
  108. Args:
  109. flat: A flat dictionary with keys that are paths joined by `sep`.
  110. sep: The separator used in the flat dictionary keys.
  111. Returns:
  112. A nested dictionary.
  113. """
  114. nested = {}
  115. for compound_key, value in flat.items():
  116. # Split all keys by the separator.
  117. keys = compound_key.split(sep)
  118. current = nested
  119. # Nest by the separated keys.
  120. for key in keys[:-1]:
  121. if key not in current:
  122. current[key] = {}
  123. current = current[key]
  124. current[keys[-1]] = value
  125. return nested
  126. @DeveloperAPI
  127. class NullContextManager(contextlib.AbstractContextManager):
  128. """No-op context manager"""
  129. def __init__(self):
  130. pass
  131. def __enter__(self):
  132. pass
  133. def __exit__(self, *args):
  134. pass
  135. force_tuple = partial(force_list, to_tuple=True)
  136. __all__ = [
  137. "add_mixins",
  138. "check",
  139. "check_compute_single_action",
  140. "check_train_results",
  141. "deep_update",
  142. "deprecation_warning",
  143. "fc",
  144. "force_list",
  145. "force_tuple",
  146. "flatten_dict",
  147. "unflatten_dict",
  148. "lstm",
  149. "merge_dicts",
  150. "one_hot",
  151. "override",
  152. "relu",
  153. "sigmoid",
  154. "softmax",
  155. "try_import_jax",
  156. "try_import_tf",
  157. "try_import_tfp",
  158. "try_import_torch",
  159. "ConstantSchedule",
  160. "DeveloperAPI",
  161. "ExponentialSchedule",
  162. "Filter",
  163. "FilterManager",
  164. "LARGE_INTEGER",
  165. "LinearSchedule",
  166. "MAX_LOG_NN_OUTPUT",
  167. "MIN_LOG_NN_OUTPUT",
  168. "PiecewiseSchedule",
  169. "PolynomialSchedule",
  170. "PublicAPI",
  171. "SMALL_NUMBER",
  172. ]