| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523 |
- import copy
- import logging
- import random
- import re
- from collections.abc import Mapping
- from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple
- import numpy
- from ray.tune.search.sample import Categorical, Domain, Function, RandomState
- from ray.util.annotations import DeveloperAPI, PublicAPI
- logger = logging.getLogger(__name__)
- @DeveloperAPI
def generate_variants(
    unresolved_spec: Dict,
    constant_grid_search: bool = False,
    random_state: "RandomState" = None,
) -> Generator[Tuple[Dict, Dict], None, None]:
    """Yield every fully resolved variant of a spec with unresolved values.

    Two kinds of unresolved values are supported:

    Grid search: a grid of values to enumerate. For example, the following
    grid search values in a spec produce six distinct variants in
    combination:

        "activation": grid_search(["relu", "tanh"])
        "learning_rate": grid_search([1e-3, 1e-4, 1e-5])

    Lambda functions: evaluated to produce a concrete value. They can
    express dependencies or conditional distributions between values, and
    can also express random search (e.g., by calling into the `random` or
    `np` module):

        "cpu": lambda spec: spec.config.num_workers
        "batch_size": lambda spec: random.uniform(1, 1000)

    To allow specs written in plain JSON / YAML, both kinds can
    alternatively be spelled as dicts:

        "activation": {"grid_search": ["relu", "tanh"]}
        "cpu": {"eval": "spec.config.num_workers"}

    Use `format_vars` to format the returned dict of hyperparameters.

    Args:
        unresolved_spec: The spec (dict) to expand.
        constant_grid_search: Forwarded to the internal generator; when set,
            random variables are sampled first and kept constant across the
            grid search.
        random_state: Forwarded to the sampling calls of the domains.

    Yields:
        (Dict of resolved variables, Spec object)
    """
    variants = _generate_variants_internal(
        unresolved_spec,
        constant_grid_search=constant_grid_search,
        random_state=random_state,
    )
    for resolved_vars, spec in variants:
        # Invariant of the internal generator: nothing may be left unresolved.
        assert not _unresolved_values(spec)
        yield resolved_vars, spec
- @PublicAPI(stability="beta")
def grid_search(values: Iterable) -> Dict[str, Iterable]:
    """Specify a grid of values to search over.

    Values specified in a grid search are guaranteed to be sampled.

    If multiple grid search variables are defined, they are combined with the
    combinatorial product. This means every possible combination of values
    will be sampled.

    Example:

        >>> from ray import tune
        >>> param_space={
        ...   "x": tune.grid_search([10, 20]),
        ...   "y": tune.grid_search(["a", "b", "c"])
        ... }

    This will create a grid of 6 samples:
    ``{"x": 10, "y": "a"}``, ``{"x": 10, "y": "b"}``, etc.

    When specifying ``num_samples`` in the
    :class:`TuneConfig <ray.tune.tune_config.TuneConfig>`, this will specify
    the number of random samples per grid search combination.

    For instance, in the example above, if ``num_samples=4``,
    a total of 24 trials will be started -
    4 trials for each of the 6 grid search combinations.

    Args:
        values: An iterable whose parameters will be used for creating a
            trial grid.

    Returns:
        A single-entry dict mapping ``"grid_search"`` to ``values`` (the
        iterable is stored as given, not copied).
    """
    marker = {}
    marker["grid_search"] = values
    return marker
# Names made available (as globals) to expressions given in the
# ``{"eval": "..."}`` syntax.
_STANDARD_IMPORTS = dict(random=random, np=numpy)

# Upper bound on the number of sweeps over the domain variables when
# resolving values that depend on one another.
_MAX_RESOLUTION_PASSES = 20
def _resolve_nested_dict(nested_dict: Dict) -> Dict[Tuple, Any]:
    """Flatten a nested dict into a dict keyed by tuple paths.

    Every non-dict leaf ends up under the tuple of keys leading to it; nested
    dicts are descended into recursively (an empty nested dict therefore
    contributes no entry). The result can be passed into `format_vars`.
    """
    flat = {}
    for key, value in nested_dict.items():
        if not isinstance(value, dict):
            flat[(key,)] = value
            continue
        for sub_path, leaf in _resolve_nested_dict(value).items():
            flat[(key, *sub_path)] = leaf
    return flat
- @DeveloperAPI
def format_vars(resolved_vars: Dict) -> str:
    """Format variables to be used as experiment tags.

    Experiment tags are used in directory names, so this method makes sure
    the resulting tags can be legally used in directory names on all systems.

    The input to this function is a dict of the form
    ``{("nested", "config", "path"): "value"}``. The output will be a comma
    separated string of the form ``last_key=value``, so in this example
    ``path=value``. Entries are emitted in sorted key order.

    Note that the sanitizing implies that empty strings are possible return
    values. This is expected and acceptable, as it is not a common case and
    the resulting directory names will still be valid.

    Args:
        resolved_vars: Dictionary mapping from config path tuples to a value.
            The dict itself is not modified.

    Returns:
        Comma-separated key=value string.
    """
    tag_vars = resolved_vars.copy()
    # TrialRunner already has these in the experiment_tag
    # NOTE(review): these are plain-string keys, while the documented input
    # uses tuple paths -- confirm whether ("run",) etc. were meant instead.
    for skipped in ("run", "env", "resources_per_trial"):
        tag_vars.pop(skipped, None)

    parts = []
    for path, value in sorted(tag_vars.items()):
        parts.append(f"{_clean_value(path[-1])}={_clean_value(value)}")
    return ",".join(parts)
def _flatten_resolved_vars(resolved_vars: Dict) -> Dict:
    """Formats the resolved variable dict into a mapping of (str -> value).

    Each key is a path whose pieces are stringified and joined with ``/``;
    a leading ``"config"`` piece is dropped first.
    """
    flattened = {}
    for path, value in resolved_vars.items():
        trimmed = path[1:] if path[0] == "config" else path
        flattened["/".join(str(piece) for piece in trimmed)] = value
    return flattened
def _clean_value(value: Any) -> str:
    """Format floats and replace invalid string characters with ``_``.

    Floats are rendered with four decimal places. Anything else is converted
    with ``str``; every run of characters outside ``a-zA-Z0-9_-`` collapses to
    a single ``_`` and leading/trailing underscores are stripped.
    """
    if not isinstance(value, float):
        # The pattern is the inverse of the allowed alphabet.
        sanitized = re.sub(r"[^a-zA-Z0-9_-]+", "_", str(value))
        return sanitized.strip("_")
    return f"{value:.4f}"
- @DeveloperAPI
def parse_spec_vars(
    spec: Dict,
) -> Tuple[List[Tuple[Tuple, Any]], List[Tuple[Tuple, Any]], List[Tuple[Tuple, Any]]]:
    """Split a spec into resolved values, domain variables and grid variables.

    Args:
        spec: The (possibly nested) spec to inspect.

    Returns:
        A tuple ``(resolved_vars, domain_vars, grid_vars)``. Each element is a
        list of ``(path, value)`` pairs. Unresolved values whose ``is_grid()``
        is true go to ``grid_vars`` (returned sorted); all other unresolved
        values go to ``domain_vars``.
    """
    resolved, unresolved = _split_resolved_unresolved_values(spec)
    resolved_vars = list(resolved.items())

    domain_vars = []
    grid_vars = []
    for entry in unresolved.items():
        bucket = grid_vars if entry[1].is_grid() else domain_vars
        bucket.append(entry)
    grid_vars.sort()
    return resolved_vars, domain_vars, grid_vars
def _count_spec_samples(spec: Dict, num_samples=1) -> int:
    """Count samples for a specific spec

    The result is ``num_samples`` multiplied by the number of grid search
    combinations, i.e. the product of the category counts of all grid
    variables found in ``spec``.
    """
    grid_vars = parse_spec_vars(spec)[2]
    combinations = 1
    for _, grid_domain in grid_vars:
        combinations *= len(grid_domain.categories)
    return num_samples * combinations
def _count_variants(spec: Dict, presets: Optional[List[Dict]] = None) -> int:
    """Count the variants generated for a spec, taking presets into account.

    Each preset is deep-merged into a copy of ``spec["config"]``, counted as a
    single sample, and uses up one of the spec's ``num_samples`` (default 1).
    Whatever sample budget is left afterwards is counted against the
    unmodified spec.

    Args:
        spec: Experiment spec. ``spec["config"]`` must exist when presets are
            given.
        presets: Optional list of config overrides. ``None`` (the default) is
            treated like an empty list.

    Returns:
        The total number of samples.
    """

    # Helper function: Deep update dictionary
    def deep_update(d, u):
        for k, v in u.items():
            if isinstance(v, Mapping):
                d[k] = deep_update(d.get(k, {}), v)
            else:
                d[k] = v
        return d

    total_samples = 0
    total_num_samples = spec.get("num_samples", 1)

    # For each preset, overwrite the spec and count the samples generated
    # for this preset. `presets` defaults to None, which must not be iterated.
    for preset in presets or []:
        preset_spec = copy.deepcopy(spec)
        deep_update(preset_spec["config"], preset)
        total_samples += _count_spec_samples(preset_spec, 1)
        total_num_samples -= 1

    # Add the remaining samples
    if total_num_samples > 0:
        total_samples += _count_spec_samples(spec, total_num_samples)
    return total_samples
def _generate_variants_internal(
    spec: Dict, constant_grid_search: bool = False, random_state: "RandomState" = None
) -> Generator[Tuple[Dict, Dict], None, None]:
    """Recursively resolve a spec, yielding ``(resolved_vars, spec)`` pairs.

    The spec is deep-copied first, so the caller's dict is not modified. If it
    contains neither domain nor grid variables, a single ``({}, spec)`` pair
    is yielded. Otherwise every grid search combination is generated, the
    domain variables are sampled, and the function recurses on the result
    until nothing unresolved is left.

    Args:
        spec: Spec that may still contain unresolved values.
        constant_grid_search: If true, domain variables are sampled once up
            front and kept constant for all grid search combinations; only
            those that could not be resolved up front are sampled again per
            combination.
        random_state: Forwarded to the sampling of the domain variables.

    Yields:
        Tuples of (dict mapping variable paths to resolved values, resolved
        spec).

    Raises:
        ValueError: If a variable resolves to two different values.
    """
    spec = copy.deepcopy(spec)
    _, domain_vars, grid_vars = parse_spec_vars(spec)

    if not domain_vars and not grid_vars:
        yield {}, spec
        return

    # Variables to resolve
    to_resolve = domain_vars

    all_resolved = True
    if constant_grid_search:
        # In this path, we first sample random variables and keep them constant
        # for grid search.
        # `_resolve_domain_vars` will alter `spec` directly
        all_resolved, resolved_vars = _resolve_domain_vars(
            spec, domain_vars, allow_fail=True, random_state=random_state
        )
        if not all_resolved:
            # Not all variables have been resolved, but remove those that have
            # from the `to_resolve` list.
            to_resolve = [(r, d) for r, d in to_resolve if r not in resolved_vars]

    for resolved_spec in _grid_search_generator(spec, grid_vars):
        if not constant_grid_search or not all_resolved:
            # In this path, we sample the remaining random variables
            _, resolved_vars = _resolve_domain_vars(
                resolved_spec, to_resolve, random_state=random_state
            )

        # NOTE: when everything was resolved up front (constant grid search),
        # `resolved_vars` is one dict object that is updated and yielded again
        # for every variant.
        for resolved, variant_spec in _generate_variants_internal(
            resolved_spec,
            constant_grid_search=constant_grid_search,
            random_state=random_state,
        ):
            for path, _ in grid_vars:
                resolved_vars[path] = _get_value(variant_spec, path)
            for k, v in resolved.items():
                if (
                    k in resolved_vars
                    and v != resolved_vars[k]
                    and _is_resolved(resolved_vars[k])
                ):
                    raise ValueError(
                        "The variable `{}` could not be unambiguously "
                        "resolved to a single value. Consider simplifying "
                        "your configuration.".format(k)
                    )
                resolved_vars[k] = v
            yield resolved_vars, variant_spec
def _get_preset_variants(
    spec: Dict,
    config: Dict,
    constant_grid_search: bool = False,
    random_state: "RandomState" = None,
):
    """Get variants according to a spec, initialized with a config.

    Variables from the spec are overwritten by the variables in the config.
    Thus, we may end up with less sampled parameters.

    This function also checks if values used to overwrite search space
    parameters are valid, and logs a warning if not.

    Args:
        spec: Experiment spec; it is deep-copied, and the overrides are
            applied to the copy's ``"config"`` entry.
        config: Pre-set values that overwrite entries of the search space.
        constant_grid_search: Forwarded to the variant generator.
        random_state: Forwarded to the variant generator.

    Returns:
        The generator produced by `_generate_variants_internal` for the
        overwritten spec.

    Raises:
        ValueError: If a pre-set path does not exist in the search space.
    """
    spec = copy.deepcopy(spec)

    resolved, _, _ = parse_spec_vars(config)

    for path, val in resolved:
        # Paths may contain list/tuple indices, so stringify before joining.
        path_str = "/".join(str(piece) for piece in path)
        search_space = spec["config"]
        try:
            domain = _get_value(search_space, path)
        except (KeyError, IndexError, TypeError) as exc:
            # KeyError: missing dict key; IndexError: list/tuple index out of
            # range; TypeError: the path descends into a non-container.
            raise ValueError(
                f"Pre-set config key `{path_str}` does not correspond "
                f"to a valid key in the search space definition. Please add "
                f"this path to the `param_space` variable passed to `tune.Tuner()`."
            ) from exc

        if isinstance(domain, dict):
            if "grid_search" in domain:
                domain = Categorical(domain["grid_search"])
            else:
                # If users want to overwrite an entire subdict,
                # let them do it.
                domain = None

        # Compare against None explicitly so that falsy fixed values
        # (0, False, "") are still checked.
        if domain is not None:
            if isinstance(domain, Domain):
                if not domain.is_valid(val):
                    logger.warning(
                        f"Pre-set value `{val}` is not within valid values of "
                        f"parameter `{path_str}`: {domain.domain_str}"
                    )
            else:
                # domain is actually a fixed value
                if domain != val:
                    logger.warning(
                        f"Pre-set value `{val}` is not equal to the value of "
                        f"parameter `{path_str}`: {domain}"
                    )
        assign_value(spec["config"], path, val)

    return _generate_variants_internal(
        spec, constant_grid_search=constant_grid_search, random_state=random_state
    )
- @DeveloperAPI
def assign_value(spec: Dict, path: Tuple, value: Any):
    """Assigns a value to a nested dictionary.

    Handles the special case of tuples, in which case the tuples
    will be re-constructed to accommodate the updated value. Tuples nested
    directly inside other tuples are rebuilt level by level until a mutable
    container is reached.

    Args:
        spec: The nested container to update in place.
        path: Keys / indices leading to the location to assign.
        value: The value to store at ``path``.

    Raises:
        ValueError: If the assignment would require replacing ``spec`` itself
            because it is a tuple (there is no parent to store the rebuilt
            tuple in).
    """
    # Walk down to the container holding the final key, remembering for each
    # level which parent/key pair led to it.
    ancestors = []
    container = spec
    for k in path[:-1]:
        ancestors.append((container, k))
        container = container[k]
    key = path[-1]

    # Tuples are immutable: build a replacement tuple and store that in the
    # parent instead, repeating while the parent is a tuple as well.
    while isinstance(container, tuple):
        if not ancestors:
            raise ValueError("Cannot assign value to a tuple.")
        assert isinstance(key, int), "Tuple key must be an int."
        value = container[:key] + (value,) + container[key + 1 :]
        container, key = ancestors.pop()

    # The container is mutable. Just assign the value.
    container[key] = value
def _get_value(spec: Dict, path: Tuple) -> Any:
    """Return the value found by indexing into ``spec`` along ``path``.

    An empty path returns ``spec`` itself.
    """
    node = spec
    for step in path:
        node = node[step]
    return node
def _resolve_domain_vars(
    spec: Dict,
    domain_vars: List[Tuple[Tuple, Domain]],
    allow_fail: bool = False,
    random_state: "RandomState" = None,
) -> Tuple[bool, Dict]:
    """Sample the given domain variables and write the values into ``spec``.

    Variables may depend on values that are not resolved yet; sampling those
    raises ``RecursiveDependencyError``. They are retried in further passes
    over ``domain_vars`` (at most ``_MAX_RESOLUTION_PASSES`` passes in total).

    Args:
        spec: Spec to sample against; modified in place.
        domain_vars: ``(path, domain)`` pairs to resolve.
        allow_fail: If true, report unresolved dependencies via the return
            value instead of raising.
        random_state: Forwarded to ``domain.sample``.

    Returns:
        ``(all_resolved, resolved)`` where ``resolved`` maps each resolved
        path to its sampled value.

    Raises:
        RecursiveDependencyError: If dependencies stay unresolved after the
            last pass and ``allow_fail`` is false.
        ValueError: If sampling a domain fails for any other reason.
    """
    resolved = {}
    failure = True
    for _ in range(_MAX_RESOLUTION_PASSES):
        failure = False
        for path, domain in domain_vars:
            if path in resolved:
                continue
            try:
                value = domain.sample(
                    _UnresolvedAccessGuard(spec), random_state=random_state
                )
            except RecursiveDependencyError as dependency_error:
                # Remember the failure and retry this variable next pass.
                failure = dependency_error
                continue
            except Exception:
                raise ValueError(
                    "Failed to evaluate expression: {}: {}".format(path, domain)
                )
            assign_value(spec, path, value)
            resolved[path] = value
        if not failure:
            break

    if not failure:
        return True, resolved
    if allow_fail:
        return False, resolved
    raise failure
def _grid_search_generator(
    unresolved_spec: Dict, grid_vars: List
) -> Generator[Dict, None, None]:
    """Yield one spec per combination of grid search values.

    Without grid variables, ``unresolved_spec`` itself is yielded once.
    Otherwise each yielded spec is a deep copy of ``unresolved_spec`` with
    one value assigned per grid variable. Combinations are enumerated like an
    odometer in which the first grid variable changes fastest.

    Args:
        unresolved_spec: Spec to copy for every combination.
        grid_vars: ``(path, values)`` pairs; ``values`` must support ``len``
            and integer indexing.
    """
    if not grid_vars:
        yield unresolved_spec
        return

    cursor = [0] * len(grid_vars)
    while cursor[-1] < len(grid_vars[-1][1]):
        variant = copy.deepcopy(unresolved_spec)
        for (path, values), index in zip(grid_vars, cursor):
            assign_value(variant, path, values[index])
        yield variant

        # Advance the odometer: bump the first position, carrying over into
        # the next one whenever a position wraps around.
        for position in range(len(cursor)):
            cursor[position] += 1
            if cursor[position] < len(grid_vars[position][1]):
                break
            cursor[position] = 0
        else:
            # Every position wrapped around: all combinations were produced.
            return
def _is_resolved(v) -> bool:
    """Return whether ``v`` is a resolved value according to `_try_resolve`."""
    return _try_resolve(v)[0]
def _try_resolve(v) -> Tuple[bool, Any]:
    """Classify ``v`` as resolved or unresolved.

    Returns:
        ``(True, v)`` for a plain value. Otherwise ``(False, domain)``, where
        ``domain`` is ``v`` itself if it already is a ``Domain``, a
        ``Function`` wrapping the expression for the single-key
        ``{"eval": ...}`` form, or a grid ``Categorical`` for the single-key
        ``{"grid_search": ...}`` form.
    """
    if isinstance(v, Domain):
        # Domain to sample from
        return False, v

    if isinstance(v, dict) and len(v) == 1:
        if "eval" in v:
            # Lambda function in eval syntax
            # NOTE(security): the expression string is passed to `eval`;
            # it must only ever come from a trusted spec.
            return False, Function(
                lambda spec: eval(v["eval"], _STANDARD_IMPORTS, {"spec": spec})
            )
        if "grid_search" in v:
            # Grid search values
            return False, Categorical(v["grid_search"]).grid()

    return True, v
def _split_resolved_unresolved_values(
    spec: Dict,
) -> Tuple[Dict[Tuple, Any], Dict[Tuple, Any]]:
    """Partition the leaves of ``spec`` into resolved and unresolved values.

    Dicts, lists and tuples are descended into recursively (list/tuple
    elements are addressed by their index), so empty containers contribute no
    entry.

    Returns:
        ``(resolved_vars, unresolved_vars)``, both keyed by the tuple path
        leading to the value. Unresolved entries hold the domain object
        returned by `_try_resolve`.
    """
    resolved_vars = {}
    unresolved_vars = {}
    for key, raw_value in spec.items():
        is_resolved, value = _try_resolve(raw_value)
        if not is_resolved:
            unresolved_vars[(key,)] = value
            continue

        if isinstance(value, dict):
            # Recurse into a dict
            children = [value]
        elif isinstance(value, (list, tuple)):
            # Recurse into a list, one element at a time
            children = [{index: element} for index, element in enumerate(value)]
        else:
            resolved_vars[(key,)] = value
            continue

        for child in children:
            child_resolved, child_unresolved = _split_resolved_unresolved_values(
                child
            )
            for sub_path, leaf in child_resolved.items():
                resolved_vars[(key,) + sub_path] = leaf
            for sub_path, leaf in child_unresolved.items():
                unresolved_vars[(key,) + sub_path] = leaf
    return resolved_vars, unresolved_vars
def _unresolved_values(spec: Dict) -> Dict[Tuple, Any]:
    """Return only the unresolved part (path -> domain) of ``spec``."""
    _, unresolved = _split_resolved_unresolved_values(spec)
    return unresolved
def _has_unresolved_values(spec: Dict) -> bool:
    """Return whether ``spec`` contains at least one unresolved value."""
    return bool(_unresolved_values(spec))
class _UnresolvedAccessGuard(dict):
    """Dict with attribute access that refuses to expose unresolved values.

    The instance serves as its own ``__dict__``, so every entry is also
    reachable as an attribute. Attribute access raises
    ``RecursiveDependencyError`` when the value is not resolved yet, and
    wraps dict values in another guard so that nested access is checked too.
    """

    def __init__(self, *args, **kwds):
        super().__init__(*args, **kwds)
        # Make the entries available as attributes.
        self.__dict__ = self

    def __getattribute__(self, item):
        value = dict.__getattribute__(self, item)
        if not _is_resolved(value):
            raise RecursiveDependencyError(
                "`{}` recursively depends on {}".format(item, value)
            )
        if isinstance(value, dict):
            return _UnresolvedAccessGuard(value)
        return value
- @DeveloperAPI
class RecursiveDependencyError(Exception):
    """Signals that a value was accessed before it has been resolved."""

    def __init__(self, msg: str):
        super().__init__(msg)
|