variant_generator.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523
  1. import copy
  2. import logging
  3. import random
  4. import re
  5. from collections.abc import Mapping
  6. from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple
  7. import numpy
  8. from ray.tune.search.sample import Categorical, Domain, Function, RandomState
  9. from ray.util.annotations import DeveloperAPI, PublicAPI
  10. logger = logging.getLogger(__name__)
  11. @DeveloperAPI
  12. def generate_variants(
  13. unresolved_spec: Dict,
  14. constant_grid_search: bool = False,
  15. random_state: "RandomState" = None,
  16. ) -> Generator[Tuple[Dict, Dict], None, None]:
  17. """Generates variants from a spec (dict) with unresolved values.
  18. There are two types of unresolved values:
  19. Grid search: These define a grid search over values. For example, the
  20. following grid search values in a spec will produce six distinct
  21. variants in combination:
  22. "activation": grid_search(["relu", "tanh"])
  23. "learning_rate": grid_search([1e-3, 1e-4, 1e-5])
  24. Lambda functions: These are evaluated to produce a concrete value, and
  25. can express dependencies or conditional distributions between values.
  26. They can also be used to express random search (e.g., by calling
  27. into the `random` or `np` module).
  28. "cpu": lambda spec: spec.config.num_workers
  29. "batch_size": lambda spec: random.uniform(1, 1000)
  30. Finally, to support defining specs in plain JSON / YAML, grid search
  31. and lambda functions can also be defined alternatively as follows:
  32. "activation": {"grid_search": ["relu", "tanh"]}
  33. "cpu": {"eval": "spec.config.num_workers"}
  34. Use `format_vars` to format the returned dict of hyperparameters.
  35. Yields:
  36. (Dict of resolved variables, Spec object)
  37. """
  38. for resolved_vars, spec in _generate_variants_internal(
  39. unresolved_spec,
  40. constant_grid_search=constant_grid_search,
  41. random_state=random_state,
  42. ):
  43. assert not _unresolved_values(spec)
  44. yield resolved_vars, spec
  45. @PublicAPI(stability="beta")
  46. def grid_search(values: Iterable) -> Dict[str, Iterable]:
  47. """Specify a grid of values to search over.
  48. Values specified in a grid search are guaranteed to be sampled.
  49. If multiple grid search variables are defined, they are combined with the
  50. combinatorial product. This means every possible combination of values will
  51. be sampled.
  52. Example:
  53. >>> from ray import tune
  54. >>> param_space={
  55. ... "x": tune.grid_search([10, 20]),
  56. ... "y": tune.grid_search(["a", "b", "c"])
  57. ... }
  58. This will create a grid of 6 samples:
  59. ``{"x": 10, "y": "a"}``, ``{"x": 10, "y": "b"}``, etc.
  60. When specifying ``num_samples`` in the
  61. :class:`TuneConfig <ray.tune.tune_config.TuneConfig>`, this will specify
  62. the number of random samples per grid search combination.
  63. For instance, in the example above, if ``num_samples=4``,
  64. a total of 24 trials will be started -
  65. 4 trials for each of the 6 grid search combinations.
  66. Args:
  67. values: An iterable whose parameters will be used for creating a trial grid.
  68. """
  69. return {"grid_search": values}
  70. _STANDARD_IMPORTS = {
  71. "random": random,
  72. "np": numpy,
  73. }
  74. _MAX_RESOLUTION_PASSES = 20
  75. def _resolve_nested_dict(nested_dict: Dict) -> Dict[Tuple, Any]:
  76. """Flattens a nested dict by joining keys into tuple of paths.
  77. Can then be passed into `format_vars`.
  78. """
  79. res = {}
  80. for k, v in nested_dict.items():
  81. if isinstance(v, dict):
  82. for k_, v_ in _resolve_nested_dict(v).items():
  83. res[(k,) + k_] = v_
  84. else:
  85. res[(k,)] = v
  86. return res
  87. @DeveloperAPI
  88. def format_vars(resolved_vars: Dict) -> str:
  89. """Format variables to be used as experiment tags.
  90. Experiment tags are used in directory names, so this method makes sure
  91. the resulting tags can be legally used in directory names on all systems.
  92. The input to this function is a dict of the form
  93. ``{("nested", "config", "path"): "value"}``. The output will be a comma
  94. separated string of the form ``last_key=value``, so in this example
  95. ``path=value``.
  96. Note that the sanitizing implies that empty strings are possible return
  97. values. This is expected and acceptable, as it is not a common case and
  98. the resulting directory names will still be valid.
  99. Args:
  100. resolved_vars: Dictionary mapping from config path tuples to a value.
  101. Returns:
  102. Comma-separated key=value string.
  103. """
  104. vars = resolved_vars.copy()
  105. # TrialRunner already has these in the experiment_tag
  106. for v in ["run", "env", "resources_per_trial"]:
  107. vars.pop(v, None)
  108. return ",".join(
  109. f"{_clean_value(k[-1])}={_clean_value(v)}" for k, v in sorted(vars.items())
  110. )
  111. def _flatten_resolved_vars(resolved_vars: Dict) -> Dict:
  112. """Formats the resolved variable dict into a mapping of (str -> value)."""
  113. flattened_resolved_vars_dict = {}
  114. for pieces, value in resolved_vars.items():
  115. if pieces[0] == "config":
  116. pieces = pieces[1:]
  117. pieces = [str(piece) for piece in pieces]
  118. flattened_resolved_vars_dict["/".join(pieces)] = value
  119. return flattened_resolved_vars_dict
  120. def _clean_value(value: Any) -> str:
  121. """Format floats and replace invalid string characters with ``_``."""
  122. if isinstance(value, float):
  123. return f"{value:.4f}"
  124. else:
  125. # Define an invalid alphabet, which is the inverse of the
  126. # stated regex characters
  127. invalid_alphabet = r"[^a-zA-Z0-9_-]+"
  128. return re.sub(invalid_alphabet, "_", str(value)).strip("_")
  129. @DeveloperAPI
  130. def parse_spec_vars(
  131. spec: Dict,
  132. ) -> Tuple[List[Tuple[Tuple, Any]], List[Tuple[Tuple, Any]], List[Tuple[Tuple, Any]]]:
  133. resolved, unresolved = _split_resolved_unresolved_values(spec)
  134. resolved_vars = list(resolved.items())
  135. if not unresolved:
  136. return resolved_vars, [], []
  137. grid_vars = []
  138. domain_vars = []
  139. for path, value in unresolved.items():
  140. if value.is_grid():
  141. grid_vars.append((path, value))
  142. else:
  143. domain_vars.append((path, value))
  144. grid_vars.sort()
  145. return resolved_vars, domain_vars, grid_vars
  146. def _count_spec_samples(spec: Dict, num_samples=1) -> int:
  147. """Count samples for a specific spec"""
  148. _, domain_vars, grid_vars = parse_spec_vars(spec)
  149. grid_count = 1
  150. for path, domain in grid_vars:
  151. grid_count *= len(domain.categories)
  152. return num_samples * grid_count
  153. def _count_variants(spec: Dict, presets: Optional[List[Dict]] = None) -> int:
  154. # Helper function: Deep update dictionary
  155. def deep_update(d, u):
  156. for k, v in u.items():
  157. if isinstance(v, Mapping):
  158. d[k] = deep_update(d.get(k, {}), v)
  159. else:
  160. d[k] = v
  161. return d
  162. total_samples = 0
  163. total_num_samples = spec.get("num_samples", 1)
  164. # For each preset, overwrite the spec and count the samples generated
  165. # for this preset
  166. for preset in presets:
  167. preset_spec = copy.deepcopy(spec)
  168. deep_update(preset_spec["config"], preset)
  169. total_samples += _count_spec_samples(preset_spec, 1)
  170. total_num_samples -= 1
  171. # Add the remaining samples
  172. if total_num_samples > 0:
  173. total_samples += _count_spec_samples(spec, total_num_samples)
  174. return total_samples
  175. def _generate_variants_internal(
  176. spec: Dict, constant_grid_search: bool = False, random_state: "RandomState" = None
  177. ) -> Tuple[Dict, Dict]:
  178. spec = copy.deepcopy(spec)
  179. _, domain_vars, grid_vars = parse_spec_vars(spec)
  180. if not domain_vars and not grid_vars:
  181. yield {}, spec
  182. return
  183. # Variables to resolve
  184. to_resolve = domain_vars
  185. all_resolved = True
  186. if constant_grid_search:
  187. # In this path, we first sample random variables and keep them constant
  188. # for grid search.
  189. # `_resolve_domain_vars` will alter `spec` directly
  190. all_resolved, resolved_vars = _resolve_domain_vars(
  191. spec, domain_vars, allow_fail=True, random_state=random_state
  192. )
  193. if not all_resolved:
  194. # Not all variables have been resolved, but remove those that have
  195. # from the `to_resolve` list.
  196. to_resolve = [(r, d) for r, d in to_resolve if r not in resolved_vars]
  197. grid_search = _grid_search_generator(spec, grid_vars)
  198. for resolved_spec in grid_search:
  199. if not constant_grid_search or not all_resolved:
  200. # In this path, we sample the remaining random variables
  201. _, resolved_vars = _resolve_domain_vars(
  202. resolved_spec, to_resolve, random_state=random_state
  203. )
  204. for resolved, spec in _generate_variants_internal(
  205. resolved_spec,
  206. constant_grid_search=constant_grid_search,
  207. random_state=random_state,
  208. ):
  209. for path, value in grid_vars:
  210. resolved_vars[path] = _get_value(spec, path)
  211. for k, v in resolved.items():
  212. if (
  213. k in resolved_vars
  214. and v != resolved_vars[k]
  215. and _is_resolved(resolved_vars[k])
  216. ):
  217. raise ValueError(
  218. "The variable `{}` could not be unambiguously "
  219. "resolved to a single value. Consider simplifying "
  220. "your configuration.".format(k)
  221. )
  222. resolved_vars[k] = v
  223. yield resolved_vars, spec
  224. def _get_preset_variants(
  225. spec: Dict,
  226. config: Dict,
  227. constant_grid_search: bool = False,
  228. random_state: "RandomState" = None,
  229. ):
  230. """Get variants according to a spec, initialized with a config.
  231. Variables from the spec are overwritten by the variables in the config.
  232. Thus, we may end up with less sampled parameters.
  233. This function also checks if values used to overwrite search space
  234. parameters are valid, and logs a warning if not.
  235. """
  236. spec = copy.deepcopy(spec)
  237. resolved, _, _ = parse_spec_vars(config)
  238. for path, val in resolved:
  239. try:
  240. domain = _get_value(spec["config"], path)
  241. if isinstance(domain, dict):
  242. if "grid_search" in domain:
  243. domain = Categorical(domain["grid_search"])
  244. else:
  245. # If users want to overwrite an entire subdict,
  246. # let them do it.
  247. domain = None
  248. except IndexError as exc:
  249. raise ValueError(
  250. f"Pre-set config key `{'/'.join(path)}` does not correspond "
  251. f"to a valid key in the search space definition. Please add "
  252. f"this path to the `param_space` variable passed to `tune.Tuner()`."
  253. ) from exc
  254. if domain:
  255. if isinstance(domain, Domain):
  256. if not domain.is_valid(val):
  257. logger.warning(
  258. f"Pre-set value `{val}` is not within valid values of "
  259. f"parameter `{'/'.join(path)}`: {domain.domain_str}"
  260. )
  261. else:
  262. # domain is actually a fixed value
  263. if domain != val:
  264. logger.warning(
  265. f"Pre-set value `{val}` is not equal to the value of "
  266. f"parameter `{'/'.join(path)}`: {domain}"
  267. )
  268. assign_value(spec["config"], path, val)
  269. return _generate_variants_internal(
  270. spec, constant_grid_search=constant_grid_search, random_state=random_state
  271. )
  272. @DeveloperAPI
  273. def assign_value(spec: Dict, path: Tuple, value: Any):
  274. """Assigns a value to a nested dictionary.
  275. Handles the special case of tuples, in which case the tuples
  276. will be re-constructed to accommodate the updated value.
  277. """
  278. parent_spec = None
  279. parent_key = None
  280. for k in path[:-1]:
  281. parent_spec = spec
  282. parent_key = k
  283. spec = spec[k]
  284. key = path[-1]
  285. if not isinstance(spec, tuple):
  286. # spec is mutable. Just assign the value.
  287. spec[key] = value
  288. else:
  289. if parent_spec is None:
  290. raise ValueError("Cannot assign value to a tuple.")
  291. assert isinstance(key, int), "Tuple key must be an int."
  292. # Special handling since tuples are immutable.
  293. parent_spec[parent_key] = spec[:key] + (value,) + spec[key + 1 :]
  294. def _get_value(spec: Dict, path: Tuple) -> Any:
  295. for k in path:
  296. spec = spec[k]
  297. return spec
  298. def _resolve_domain_vars(
  299. spec: Dict,
  300. domain_vars: List[Tuple[Tuple, Domain]],
  301. allow_fail: bool = False,
  302. random_state: "RandomState" = None,
  303. ) -> Tuple[bool, Dict]:
  304. resolved = {}
  305. error = True
  306. num_passes = 0
  307. while error and num_passes < _MAX_RESOLUTION_PASSES:
  308. num_passes += 1
  309. error = False
  310. for path, domain in domain_vars:
  311. if path in resolved:
  312. continue
  313. try:
  314. value = domain.sample(
  315. _UnresolvedAccessGuard(spec), random_state=random_state
  316. )
  317. except RecursiveDependencyError as e:
  318. error = e
  319. except Exception:
  320. raise ValueError(
  321. "Failed to evaluate expression: {}: {}".format(path, domain)
  322. )
  323. else:
  324. assign_value(spec, path, value)
  325. resolved[path] = value
  326. if error:
  327. if not allow_fail:
  328. raise error
  329. else:
  330. return False, resolved
  331. return True, resolved
  332. def _grid_search_generator(
  333. unresolved_spec: Dict, grid_vars: List
  334. ) -> Generator[Dict, None, None]:
  335. value_indices = [0] * len(grid_vars)
  336. def increment(i):
  337. value_indices[i] += 1
  338. if value_indices[i] >= len(grid_vars[i][1]):
  339. value_indices[i] = 0
  340. if i + 1 < len(value_indices):
  341. return increment(i + 1)
  342. else:
  343. return True
  344. return False
  345. if not grid_vars:
  346. yield unresolved_spec
  347. return
  348. while value_indices[-1] < len(grid_vars[-1][1]):
  349. spec = copy.deepcopy(unresolved_spec)
  350. for i, (path, values) in enumerate(grid_vars):
  351. assign_value(spec, path, values[value_indices[i]])
  352. yield spec
  353. if grid_vars:
  354. done = increment(0)
  355. if done:
  356. break
  357. def _is_resolved(v) -> bool:
  358. resolved, _ = _try_resolve(v)
  359. return resolved
  360. def _try_resolve(v) -> Tuple[bool, Any]:
  361. if isinstance(v, Domain):
  362. # Domain to sample from
  363. return False, v
  364. elif isinstance(v, dict) and len(v) == 1 and "eval" in v:
  365. # Lambda function in eval syntax
  366. return False, Function(
  367. lambda spec: eval(v["eval"], _STANDARD_IMPORTS, {"spec": spec})
  368. )
  369. elif isinstance(v, dict) and len(v) == 1 and "grid_search" in v:
  370. # Grid search values
  371. grid_values = v["grid_search"]
  372. return False, Categorical(grid_values).grid()
  373. return True, v
  374. def _split_resolved_unresolved_values(
  375. spec: Dict,
  376. ) -> Tuple[Dict[Tuple, Any], Dict[Tuple, Any]]:
  377. resolved_vars = {}
  378. unresolved_vars = {}
  379. for k, v in spec.items():
  380. resolved, v = _try_resolve(v)
  381. if not resolved:
  382. unresolved_vars[(k,)] = v
  383. elif isinstance(v, dict):
  384. # Recurse into a dict
  385. (
  386. _resolved_children,
  387. _unresolved_children,
  388. ) = _split_resolved_unresolved_values(v)
  389. for path, value in _resolved_children.items():
  390. resolved_vars[(k,) + path] = value
  391. for path, value in _unresolved_children.items():
  392. unresolved_vars[(k,) + path] = value
  393. elif isinstance(v, (list, tuple)):
  394. # Recurse into a list
  395. for i, elem in enumerate(v):
  396. (
  397. _resolved_children,
  398. _unresolved_children,
  399. ) = _split_resolved_unresolved_values({i: elem})
  400. for path, value in _resolved_children.items():
  401. resolved_vars[(k,) + path] = value
  402. for path, value in _unresolved_children.items():
  403. unresolved_vars[(k,) + path] = value
  404. else:
  405. resolved_vars[(k,)] = v
  406. return resolved_vars, unresolved_vars
  407. def _unresolved_values(spec: Dict) -> Dict[Tuple, Any]:
  408. return _split_resolved_unresolved_values(spec)[1]
  409. def _has_unresolved_values(spec: Dict) -> bool:
  410. return True if _unresolved_values(spec) else False
  411. class _UnresolvedAccessGuard(dict):
  412. def __init__(self, *args, **kwds):
  413. super(_UnresolvedAccessGuard, self).__init__(*args, **kwds)
  414. self.__dict__ = self
  415. def __getattribute__(self, item):
  416. value = dict.__getattribute__(self, item)
  417. if not _is_resolved(value):
  418. raise RecursiveDependencyError(
  419. "`{}` recursively depends on {}".format(item, value)
  420. )
  421. elif isinstance(value, dict):
  422. return _UnresolvedAccessGuard(value)
  423. else:
  424. return value
  425. @DeveloperAPI
  426. class RecursiveDependencyError(Exception):
  427. def __init__(self, msg: str):
  428. Exception.__init__(self, msg)