| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244 |
- import hashlib
- from collections import defaultdict
- from typing import Any, Dict, Tuple
- from ray.tune.search.sample import Categorical, Domain, Function
- from ray.tune.search.variant_generator import assign_value
- from ray.util.annotations import DeveloperAPI
- ID_HASH_LENGTH = 8
- def create_resolvers_map():
- return defaultdict(list)
- def _id_hash(path_tuple):
- """Compute a hash for the specific placeholder based on its path."""
- return hashlib.sha256(str(path_tuple).encode("utf-8")).hexdigest()[:ID_HASH_LENGTH]
- class _FunctionResolver:
- """Replaced value for function typed objects."""
- TOKEN = "__fn_ph"
- def __init__(self, hash, fn):
- self.hash = hash
- self._fn = fn
- def resolve(self, config: Dict):
- """Some functions take a resolved spec dict as input.
- Note: Function placeholders are independently sampled during
- resolution. Therefore their random states are not restored.
- """
- return self._fn.sample(config=config)
- def get_placeholder(self) -> str:
- return (self.TOKEN, self.hash)
- class _RefResolver:
- """Replaced value for all other non-primitive objects."""
- TOKEN = "__ref_ph"
- def __init__(self, hash, value):
- self.hash = hash
- self._value = value
- def resolve(self):
- return self._value
- def get_placeholder(self) -> str:
- return (self.TOKEN, self.hash)
- def _is_primitive(x):
- """Returns True if x is a primitive type.
- Primitive types are int, float, str, bool, and None.
- """
- return isinstance(x, (int, float, str, bool)) or x is None
- @DeveloperAPI
- def inject_placeholders(
- config: Any,
- resolvers: defaultdict,
- id_prefix: Tuple = (),
- path_prefix: Tuple = (),
- ) -> Dict:
- """Replaces reference objects contained by a config dict with placeholders.
- Given a config dict, this function replaces all reference objects contained
- by this dict with placeholder strings. It recursively expands nested dicts
- and lists, and properly handles Tune native search objects such as Categorical
- and Function.
- This makes sure the config dict only contains primitive typed values, which
- can then be handled by different search algorithms.
- A few details about id_prefix and path_prefix. Consider the following config,
- where "param1" is a simple grid search of 3 tuples.
- config = {
- "param1": tune.grid_search([
- (Cat, None, None),
- (None, Dog, None),
- (None, None, Fish),
- ]),
- }
- We will replace the 3 objects contained with placeholders. And after trial
- expansion, the config may look like this:
- config = {
- "param1": (None, (placeholder, hash), None)
- }
- Now you need 2 pieces of information to resolve the placeholder. One is the
- path of ("param1", 1), which tells you that the first element of the tuple
- under "param1" key is a placeholder that needs to be resolved.
- The other is the mapping from the placeholder to the actual object. In this
- case hash -> Dog.
- id and path prefixes serve exactly this purpose here. The difference between
- these two is that id_prefix is the location of the value in the pre-injected
- config tree. So if a value is the second option in a grid_search, it gets an
- id part of 1. Injected placeholders all get unique id prefixes. path prefix
- identifies a placeholder in the expanded config tree. So for example, all
- options of a single grid_search will get the same path prefix. This is how
- we know which location has a placeholder to be resolved in the post-expansion
- tree.
- Args:
- config: The config dict to replace references in.
- resolvers: A dict from path to replaced objects.
- id_prefix: The prefix to prepend to id every single placeholders.
- path_prefix: The prefix to prepend to every path identifying
- potential locations of placeholders in an expanded tree.
- Returns:
- The config with all references replaced.
- """
- if isinstance(config, dict) and "grid_search" in config and len(config) == 1:
- config["grid_search"] = [
- # Different options gets different id prefixes.
- # But we should omit appending to path_prefix because after expansion,
- # this level will not be there.
- inject_placeholders(choice, resolvers, id_prefix + (i,), path_prefix)
- for i, choice in enumerate(config["grid_search"])
- ]
- return config
- elif isinstance(config, dict):
- return {
- k: inject_placeholders(v, resolvers, id_prefix + (k,), path_prefix + (k,))
- for k, v in config.items()
- }
- elif isinstance(config, list):
- return [
- inject_placeholders(elem, resolvers, id_prefix + (i,), path_prefix + (i,))
- for i, elem in enumerate(config)
- ]
- elif isinstance(config, tuple):
- return tuple(
- inject_placeholders(elem, resolvers, id_prefix + (i,), path_prefix + (i,))
- for i, elem in enumerate(config)
- )
- elif _is_primitive(config):
- # Primitive types.
- return config
- elif isinstance(config, Categorical):
- config.categories = [
- # Different options gets different id prefixes.
- # But we should omit appending to path_prefix because after expansion,
- # this level will not be there.
- inject_placeholders(choice, resolvers, id_prefix + (i,), path_prefix)
- for i, choice in enumerate(config.categories)
- ]
- return config
- elif isinstance(config, Function):
- # Function type.
- id_hash = _id_hash(id_prefix)
- v = _FunctionResolver(id_hash, config)
- resolvers[path_prefix].append(v)
- return v.get_placeholder()
- elif not isinstance(config, Domain):
- # Other non-search space reference objects, dataset, actor handle, etc.
- id_hash = _id_hash(id_prefix)
- v = _RefResolver(id_hash, config)
- resolvers[path_prefix].append(v)
- return v.get_placeholder()
- else:
- # All the other cases, do nothing.
- return config
- def _get_placeholder(config: Any, prefix: Tuple, path: Tuple):
- if not path:
- return prefix, config
- key = path[0]
- if isinstance(config, tuple):
- if config[0] in (_FunctionResolver.TOKEN, _RefResolver.TOKEN):
- # Found a matching placeholder.
- # Note that we do not require that the full path are consumed before
- # declaring a match. Because this placeholder may be part of a nested
- # search space. For example, the following config:
- # config = {
- # "param1": tune.grid_search([
- # tune.grid_search([Object1, 2, 3]),
- # tune.grid_search([Object2, 5, 6]),
- # ]),
- # }
- # will result in placeholders under path ("param1", 0, 0).
- # After expansion though, the choosen placeholder will live under path
- # ("param1", 0) like this: config = {"param1": (Placeholder1, 2, 3)}
- return prefix, config
- elif key < len(config):
- return _get_placeholder(
- config[key], prefix=prefix + (path[0],), path=path[1:]
- )
- elif (isinstance(config, dict) and key in config) or (
- isinstance(config, list) and key < len(config)
- ):
- # Expand config tree recursively.
- return _get_placeholder(config[key], prefix=prefix + (path[0],), path=path[1:])
- # Can not find a matching placeholder.
- return None, None
- @DeveloperAPI
- def resolve_placeholders(config: Any, replaced: defaultdict):
- """Replaces placeholders contained by a config dict with the original values.
- Args:
- config: The config to replace placeholders in.
- replaced: A dict from path to replaced objects.
- """
- def __resolve(resolver_type, args):
- for path, resolvers in replaced.items():
- assert resolvers
- if not isinstance(resolvers[0], resolver_type):
- continue
- prefix, ph = _get_placeholder(config, (), path)
- if not ph:
- # Represents an unchosen value. Just skip.
- continue
- for resolver in resolvers:
- if resolver.hash != ph[1]:
- continue
- # Found the matching resolver.
- assign_value(config, prefix, resolver.resolve(*args))
- # RefResolvers first.
- __resolve(_RefResolver, args=())
- # Functions need to be resolved after RefResolvers, in case they are
- # referencing values from the RefResolvers.
- __resolve(_FunctionResolver, args=(config,))
|