placeholder.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244
import copy
import hashlib
from collections import defaultdict
from typing import Any, Dict, Tuple

from ray.tune.search.sample import Categorical, Domain, Function
from ray.tune.search.variant_generator import assign_value
from ray.util.annotations import DeveloperAPI
  7. ID_HASH_LENGTH = 8
  8. def create_resolvers_map():
  9. return defaultdict(list)
  10. def _id_hash(path_tuple):
  11. """Compute a hash for the specific placeholder based on its path."""
  12. return hashlib.sha256(str(path_tuple).encode("utf-8")).hexdigest()[:ID_HASH_LENGTH]
  13. class _FunctionResolver:
  14. """Replaced value for function typed objects."""
  15. TOKEN = "__fn_ph"
  16. def __init__(self, hash, fn):
  17. self.hash = hash
  18. self._fn = fn
  19. def resolve(self, config: Dict):
  20. """Some functions take a resolved spec dict as input.
  21. Note: Function placeholders are independently sampled during
  22. resolution. Therefore their random states are not restored.
  23. """
  24. return self._fn.sample(config=config)
  25. def get_placeholder(self) -> str:
  26. return (self.TOKEN, self.hash)
  27. class _RefResolver:
  28. """Replaced value for all other non-primitive objects."""
  29. TOKEN = "__ref_ph"
  30. def __init__(self, hash, value):
  31. self.hash = hash
  32. self._value = value
  33. def resolve(self):
  34. return self._value
  35. def get_placeholder(self) -> str:
  36. return (self.TOKEN, self.hash)
  37. def _is_primitive(x):
  38. """Returns True if x is a primitive type.
  39. Primitive types are int, float, str, bool, and None.
  40. """
  41. return isinstance(x, (int, float, str, bool)) or x is None
  42. @DeveloperAPI
  43. def inject_placeholders(
  44. config: Any,
  45. resolvers: defaultdict,
  46. id_prefix: Tuple = (),
  47. path_prefix: Tuple = (),
  48. ) -> Dict:
  49. """Replaces reference objects contained by a config dict with placeholders.
  50. Given a config dict, this function replaces all reference objects contained
  51. by this dict with placeholder strings. It recursively expands nested dicts
  52. and lists, and properly handles Tune native search objects such as Categorical
  53. and Function.
  54. This makes sure the config dict only contains primitive typed values, which
  55. can then be handled by different search algorithms.
  56. A few details about id_prefix and path_prefix. Consider the following config,
  57. where "param1" is a simple grid search of 3 tuples.
  58. config = {
  59. "param1": tune.grid_search([
  60. (Cat, None, None),
  61. (None, Dog, None),
  62. (None, None, Fish),
  63. ]),
  64. }
  65. We will replace the 3 objects contained with placeholders. And after trial
  66. expansion, the config may look like this:
  67. config = {
  68. "param1": (None, (placeholder, hash), None)
  69. }
  70. Now you need 2 pieces of information to resolve the placeholder. One is the
  71. path of ("param1", 1), which tells you that the first element of the tuple
  72. under "param1" key is a placeholder that needs to be resolved.
  73. The other is the mapping from the placeholder to the actual object. In this
  74. case hash -> Dog.
  75. id and path prefixes serve exactly this purpose here. The difference between
  76. these two is that id_prefix is the location of the value in the pre-injected
  77. config tree. So if a value is the second option in a grid_search, it gets an
  78. id part of 1. Injected placeholders all get unique id prefixes. path prefix
  79. identifies a placeholder in the expanded config tree. So for example, all
  80. options of a single grid_search will get the same path prefix. This is how
  81. we know which location has a placeholder to be resolved in the post-expansion
  82. tree.
  83. Args:
  84. config: The config dict to replace references in.
  85. resolvers: A dict from path to replaced objects.
  86. id_prefix: The prefix to prepend to id every single placeholders.
  87. path_prefix: The prefix to prepend to every path identifying
  88. potential locations of placeholders in an expanded tree.
  89. Returns:
  90. The config with all references replaced.
  91. """
  92. if isinstance(config, dict) and "grid_search" in config and len(config) == 1:
  93. config["grid_search"] = [
  94. # Different options gets different id prefixes.
  95. # But we should omit appending to path_prefix because after expansion,
  96. # this level will not be there.
  97. inject_placeholders(choice, resolvers, id_prefix + (i,), path_prefix)
  98. for i, choice in enumerate(config["grid_search"])
  99. ]
  100. return config
  101. elif isinstance(config, dict):
  102. return {
  103. k: inject_placeholders(v, resolvers, id_prefix + (k,), path_prefix + (k,))
  104. for k, v in config.items()
  105. }
  106. elif isinstance(config, list):
  107. return [
  108. inject_placeholders(elem, resolvers, id_prefix + (i,), path_prefix + (i,))
  109. for i, elem in enumerate(config)
  110. ]
  111. elif isinstance(config, tuple):
  112. return tuple(
  113. inject_placeholders(elem, resolvers, id_prefix + (i,), path_prefix + (i,))
  114. for i, elem in enumerate(config)
  115. )
  116. elif _is_primitive(config):
  117. # Primitive types.
  118. return config
  119. elif isinstance(config, Categorical):
  120. config.categories = [
  121. # Different options gets different id prefixes.
  122. # But we should omit appending to path_prefix because after expansion,
  123. # this level will not be there.
  124. inject_placeholders(choice, resolvers, id_prefix + (i,), path_prefix)
  125. for i, choice in enumerate(config.categories)
  126. ]
  127. return config
  128. elif isinstance(config, Function):
  129. # Function type.
  130. id_hash = _id_hash(id_prefix)
  131. v = _FunctionResolver(id_hash, config)
  132. resolvers[path_prefix].append(v)
  133. return v.get_placeholder()
  134. elif not isinstance(config, Domain):
  135. # Other non-search space reference objects, dataset, actor handle, etc.
  136. id_hash = _id_hash(id_prefix)
  137. v = _RefResolver(id_hash, config)
  138. resolvers[path_prefix].append(v)
  139. return v.get_placeholder()
  140. else:
  141. # All the other cases, do nothing.
  142. return config
  143. def _get_placeholder(config: Any, prefix: Tuple, path: Tuple):
  144. if not path:
  145. return prefix, config
  146. key = path[0]
  147. if isinstance(config, tuple):
  148. if config[0] in (_FunctionResolver.TOKEN, _RefResolver.TOKEN):
  149. # Found a matching placeholder.
  150. # Note that we do not require that the full path are consumed before
  151. # declaring a match. Because this placeholder may be part of a nested
  152. # search space. For example, the following config:
  153. # config = {
  154. # "param1": tune.grid_search([
  155. # tune.grid_search([Object1, 2, 3]),
  156. # tune.grid_search([Object2, 5, 6]),
  157. # ]),
  158. # }
  159. # will result in placeholders under path ("param1", 0, 0).
  160. # After expansion though, the choosen placeholder will live under path
  161. # ("param1", 0) like this: config = {"param1": (Placeholder1, 2, 3)}
  162. return prefix, config
  163. elif key < len(config):
  164. return _get_placeholder(
  165. config[key], prefix=prefix + (path[0],), path=path[1:]
  166. )
  167. elif (isinstance(config, dict) and key in config) or (
  168. isinstance(config, list) and key < len(config)
  169. ):
  170. # Expand config tree recursively.
  171. return _get_placeholder(config[key], prefix=prefix + (path[0],), path=path[1:])
  172. # Can not find a matching placeholder.
  173. return None, None
  174. @DeveloperAPI
  175. def resolve_placeholders(config: Any, replaced: defaultdict):
  176. """Replaces placeholders contained by a config dict with the original values.
  177. Args:
  178. config: The config to replace placeholders in.
  179. replaced: A dict from path to replaced objects.
  180. """
  181. def __resolve(resolver_type, args):
  182. for path, resolvers in replaced.items():
  183. assert resolvers
  184. if not isinstance(resolvers[0], resolver_type):
  185. continue
  186. prefix, ph = _get_placeholder(config, (), path)
  187. if not ph:
  188. # Represents an unchosen value. Just skip.
  189. continue
  190. for resolver in resolvers:
  191. if resolver.hash != ph[1]:
  192. continue
  193. # Found the matching resolver.
  194. assign_value(config, prefix, resolver.resolve(*args))
  195. # RefResolvers first.
  196. __resolve(_RefResolver, args=())
  197. # Functions need to be resolved after RefResolvers, in case they are
  198. # referencing values from the RefResolvers.
  199. __resolve(_FunctionResolver, args=(config,))