# basic_variant.py (15 KB)
  1. import copy
  2. import itertools
  3. import os
  4. import uuid
  5. import warnings
  6. from pathlib import Path
  7. from typing import TYPE_CHECKING, Dict, List, Optional, Union
  8. import numpy as np
  9. from ray.air._internal.usage import tag_searcher
  10. from ray.tune.error import TuneError
  11. from ray.tune.experiment.config_parser import _create_trial_from_spec, _make_parser
  12. from ray.tune.search.sample import _BackwardsCompatibleNumpyRng, np_random_generator
  13. from ray.tune.search.search_algorithm import SearchAlgorithm
  14. from ray.tune.search.variant_generator import (
  15. _count_spec_samples,
  16. _count_variants,
  17. _flatten_resolved_vars,
  18. _get_preset_variants,
  19. format_vars,
  20. generate_variants,
  21. )
  22. from ray.tune.utils.util import _atomic_save, _load_newest_checkpoint
  23. from ray.util import PublicAPI
  24. if TYPE_CHECKING:
  25. from ray.tune.experiment import Experiment
# Grid-search size above which variants are generated lazily instead of being
# pre-generated. ``BasicVariantGenerator.add_configurations`` compares the
# per-sample variant count against this value and warns that resume ability
# is disabled when it is exceeded.
SERIALIZATION_THRESHOLD = 1e6
  27. class _VariantIterator:
  28. """Iterates over generated variants from the search space.
  29. This object also toggles between lazy evaluation and
  30. eager evaluation of samples. If lazy evaluation is enabled,
  31. this object cannot be serialized.
  32. """
  33. def __init__(self, iterable, lazy_eval=False):
  34. self.lazy_eval = lazy_eval
  35. self.iterable = iterable
  36. self._has_next = True
  37. if lazy_eval:
  38. self._load_value()
  39. else:
  40. self.iterable = list(iterable)
  41. self._has_next = bool(self.iterable)
  42. def _load_value(self):
  43. try:
  44. self.next_value = next(self.iterable)
  45. except StopIteration:
  46. self._has_next = False
  47. def has_next(self):
  48. return self._has_next
  49. def __next__(self):
  50. if self.lazy_eval:
  51. current_value = self.next_value
  52. self._load_value()
  53. return current_value
  54. current_value = self.iterable.pop(0)
  55. self._has_next = bool(self.iterable)
  56. return current_value
  57. class _TrialIterator:
  58. """Generates trials from the spec.
  59. Args:
  60. uuid_prefix: Used in creating the trial name.
  61. num_samples: Number of samples from distribution
  62. (same as tune.TuneConfig).
  63. unresolved_spec: Experiment specification
  64. that might have unresolved distributions.
  65. constant_grid_search: Should random variables be sampled
  66. first before iterating over grid variants (True) or not (False).
  67. points_to_evaluate: Configurations that will be tried out without sampling.
  68. lazy_eval: Whether variants should be generated
  69. lazily or eagerly. This is toggled depending
  70. on the size of the grid search.
  71. start: index at which to start counting trials.
  72. random_state (int | np.random.Generator | np.random.RandomState):
  73. Seed or numpy random generator to use for reproducible results.
  74. If None (default), will use the global numpy random generator
  75. (``np.random``). Please note that full reproducibility cannot
  76. be guaranteed in a distributed environment.
  77. """
  78. def __init__(
  79. self,
  80. uuid_prefix: str,
  81. num_samples: int,
  82. unresolved_spec: dict,
  83. constant_grid_search: bool = False,
  84. points_to_evaluate: Optional[List] = None,
  85. lazy_eval: bool = False,
  86. start: int = 0,
  87. random_state: Optional[
  88. Union[int, "np_random_generator", np.random.RandomState]
  89. ] = None,
  90. ):
  91. self.parser = _make_parser()
  92. self.num_samples = num_samples
  93. self.uuid_prefix = uuid_prefix
  94. self.num_samples_left = num_samples
  95. self.unresolved_spec = unresolved_spec
  96. self.constant_grid_search = constant_grid_search
  97. self.points_to_evaluate = points_to_evaluate or []
  98. self.num_points_to_evaluate = len(self.points_to_evaluate)
  99. self.counter = start
  100. self.lazy_eval = lazy_eval
  101. self.variants = None
  102. self.random_state = random_state
  103. def create_trial(self, resolved_vars, spec):
  104. trial_id = self.uuid_prefix + ("%05d" % self.counter)
  105. experiment_tag = str(self.counter)
  106. # Always append resolved vars to experiment tag?
  107. if resolved_vars:
  108. experiment_tag += "_{}".format(format_vars(resolved_vars))
  109. self.counter += 1
  110. return _create_trial_from_spec(
  111. spec,
  112. self.parser,
  113. evaluated_params=_flatten_resolved_vars(resolved_vars),
  114. trial_id=trial_id,
  115. experiment_tag=experiment_tag,
  116. )
  117. def __next__(self):
  118. """Generates Trial objects with the variant generation process.
  119. Uses a fixed point iteration to resolve variants. All trials
  120. should be able to be generated at once.
  121. See also: `ray.tune.search.variant_generator`.
  122. Returns:
  123. Trial object
  124. """
  125. if "run" not in self.unresolved_spec:
  126. raise TuneError("Must specify `run` in {}".format(self.unresolved_spec))
  127. if self.variants and self.variants.has_next():
  128. # This block will be skipped upon instantiation.
  129. # `variants` will be set later after the first loop.
  130. resolved_vars, spec = next(self.variants)
  131. return self.create_trial(resolved_vars, spec)
  132. if self.points_to_evaluate:
  133. config = self.points_to_evaluate.pop(0)
  134. self.num_samples_left -= 1
  135. self.variants = _VariantIterator(
  136. _get_preset_variants(
  137. self.unresolved_spec,
  138. config,
  139. constant_grid_search=self.constant_grid_search,
  140. random_state=self.random_state,
  141. ),
  142. lazy_eval=self.lazy_eval,
  143. )
  144. resolved_vars, spec = next(self.variants)
  145. return self.create_trial(resolved_vars, spec)
  146. elif self.num_samples_left > 0:
  147. self.variants = _VariantIterator(
  148. generate_variants(
  149. self.unresolved_spec,
  150. constant_grid_search=self.constant_grid_search,
  151. random_state=self.random_state,
  152. ),
  153. lazy_eval=self.lazy_eval,
  154. )
  155. self.num_samples_left -= 1
  156. resolved_vars, spec = next(self.variants)
  157. return self.create_trial(resolved_vars, spec)
  158. else:
  159. raise StopIteration
  160. def __iter__(self):
  161. return self
  162. @PublicAPI
  163. class BasicVariantGenerator(SearchAlgorithm):
  164. """Uses Tune's variant generation for resolving variables.
  165. This is the default search algorithm used if no other search algorithm
  166. is specified.
  167. Args:
  168. points_to_evaluate: Initial parameter suggestions to be run
  169. first. This is for when you already have some good parameters
  170. you want to run first to help the algorithm make better suggestions
  171. for future parameters. Needs to be a list of dicts containing the
  172. configurations.
  173. max_concurrent: Maximum number of concurrently running trials.
  174. If 0 (default), no maximum is enforced.
  175. constant_grid_search: If this is set to ``True``, Ray Tune will
  176. *first* try to sample random values and keep them constant over
  177. grid search parameters. If this is set to ``False`` (default),
  178. Ray Tune will sample new random parameters in each grid search
  179. condition.
  180. random_state:
  181. Seed or numpy random generator to use for reproducible results.
  182. If None (default), will use the global numpy random generator
  183. (``np.random``). Please note that full reproducibility cannot
  184. be guaranteed in a distributed environment.
  185. Example:
  186. .. code-block:: python
  187. from ray import tune
  188. # This will automatically use the `BasicVariantGenerator`
  189. tuner = tune.Tuner(
  190. lambda config: config["a"] + config["b"],
  191. tune_config=tune.TuneConfig(
  192. num_samples=4
  193. ),
  194. param_space={
  195. "a": tune.grid_search([1, 2]),
  196. "b": tune.randint(0, 3)
  197. },
  198. )
  199. tuner.fit()
  200. In the example above, 8 trials will be generated: For each sample
  201. (``4``), each of the grid search variants for ``a`` will be sampled
  202. once. The ``b`` parameter will be sampled randomly.
  203. The generator accepts a pre-set list of points that should be evaluated.
  204. The points will replace the first samples of each experiment passed to
  205. the ``BasicVariantGenerator``.
  206. Each point will replace one sample of the specified ``num_samples``. If
  207. grid search variables are overwritten with the values specified in the
  208. presets, the number of samples will thus be reduced.
  209. Example:
  210. .. code-block:: python
  211. from ray import tune
  212. from ray.tune.search.basic_variant import BasicVariantGenerator
  213. tuner = tune.Tuner(
  214. lambda config: config["a"] + config["b"],
  215. tune_config=tune.TuneConfig(
  216. search_alg=BasicVariantGenerator(points_to_evaluate=[
  217. {"a": 2, "b": 2},
  218. {"a": 1},
  219. {"b": 2}
  220. ]),
  221. num_samples=4
  222. ),
  223. param_space={
  224. "a": tune.grid_search([1, 2]),
  225. "b": tune.randint(0, 3)
  226. },
  227. )
  228. tuner.fit()
  229. The example above will produce six trials via four samples:
  230. - The first sample will produce one trial with ``a=2`` and ``b=2``.
  231. - The second sample will produce one trial with ``a=1`` and ``b`` sampled
  232. randomly
  233. - The third sample will produce two trials, one for each grid search
  234. value of ``a``. It will be ``b=2`` for both of these trials.
  235. - The fourth sample will produce two trials, one for each grid search
  236. value of ``a``. ``b`` will be sampled randomly and independently for
  237. both of these trials.
  238. """
  239. CKPT_FILE_TMPL = "basic-variant-state-{}.json"
  240. def __init__(
  241. self,
  242. points_to_evaluate: Optional[List[Dict]] = None,
  243. max_concurrent: int = 0,
  244. constant_grid_search: bool = False,
  245. random_state: Optional[
  246. Union[int, "np_random_generator", np.random.RandomState]
  247. ] = None,
  248. ):
  249. tag_searcher(self)
  250. self._trial_generator = []
  251. self._iterators = []
  252. self._trial_iter = None
  253. self._finished = False
  254. self._random_state = _BackwardsCompatibleNumpyRng(random_state)
  255. self._points_to_evaluate = points_to_evaluate or []
  256. # Unique prefix for all trials generated, e.g., trial ids start as
  257. # 2f1e_00001, 2f1ef_00002, 2f1ef_0003, etc. Overridable for testing.
  258. force_test_uuid = os.environ.get("_TEST_TUNE_TRIAL_UUID")
  259. if force_test_uuid:
  260. self._uuid_prefix = force_test_uuid + "_"
  261. else:
  262. self._uuid_prefix = str(uuid.uuid1().hex)[:5] + "_"
  263. self._total_samples = 0
  264. self.max_concurrent = max_concurrent
  265. self._constant_grid_search = constant_grid_search
  266. self._live_trials = set()
  267. @property
  268. def total_samples(self):
  269. return self._total_samples
  270. def add_configurations(
  271. self, experiments: Union["Experiment", List["Experiment"], Dict[str, Dict]]
  272. ):
  273. """Chains generator given experiment specifications.
  274. Arguments:
  275. experiments: Experiments to run.
  276. """
  277. from ray.tune.experiment import _convert_to_experiment_list
  278. experiment_list = _convert_to_experiment_list(experiments)
  279. for experiment in experiment_list:
  280. grid_vals = _count_spec_samples(experiment.spec, num_samples=1)
  281. lazy_eval = grid_vals > SERIALIZATION_THRESHOLD
  282. if lazy_eval:
  283. warnings.warn(
  284. f"The number of pre-generated samples ({grid_vals}) "
  285. "exceeds the serialization threshold "
  286. f"({int(SERIALIZATION_THRESHOLD)}). Resume ability is "
  287. "disabled. To fix this, reduce the number of "
  288. "dimensions/size of the provided grid search."
  289. )
  290. previous_samples = self._total_samples
  291. points_to_evaluate = copy.deepcopy(self._points_to_evaluate)
  292. self._total_samples += _count_variants(experiment.spec, points_to_evaluate)
  293. iterator = _TrialIterator(
  294. uuid_prefix=self._uuid_prefix,
  295. num_samples=experiment.spec.get("num_samples", 1),
  296. unresolved_spec=experiment.spec,
  297. constant_grid_search=self._constant_grid_search,
  298. points_to_evaluate=points_to_evaluate,
  299. lazy_eval=lazy_eval,
  300. start=previous_samples,
  301. random_state=self._random_state,
  302. )
  303. self._iterators.append(iterator)
  304. self._trial_generator = itertools.chain(self._trial_generator, iterator)
  305. def next_trial(self):
  306. """Provides one Trial object to be queued into the TrialRunner.
  307. Returns:
  308. Trial: Returns a single trial.
  309. """
  310. if self.is_finished():
  311. return None
  312. if self.max_concurrent > 0 and len(self._live_trials) >= self.max_concurrent:
  313. return None
  314. if not self._trial_iter:
  315. self._trial_iter = iter(self._trial_generator)
  316. try:
  317. trial = next(self._trial_iter)
  318. self._live_trials.add(trial.trial_id)
  319. return trial
  320. except StopIteration:
  321. self._trial_generator = []
  322. self._trial_iter = None
  323. self.set_finished()
  324. return None
  325. def on_trial_complete(
  326. self, trial_id: str, result: Optional[Dict] = None, error: bool = False
  327. ):
  328. if trial_id in self._live_trials:
  329. self._live_trials.remove(trial_id)
  330. def get_state(self):
  331. if any(iterator.lazy_eval for iterator in self._iterators):
  332. return False
  333. state = self.__dict__.copy()
  334. del state["_trial_generator"]
  335. return state
  336. def set_state(self, state):
  337. self.__dict__.update(state)
  338. for iterator in self._iterators:
  339. self._trial_generator = itertools.chain(self._trial_generator, iterator)
  340. def save_to_dir(self, dirpath, session_str):
  341. if any(iterator.lazy_eval for iterator in self._iterators):
  342. return False
  343. state_dict = self.get_state()
  344. file_name = self.CKPT_FILE_TMPL.format(session_str)
  345. _atomic_save(
  346. state=state_dict,
  347. checkpoint_dir=dirpath,
  348. file_name=file_name,
  349. tmp_file_name=f"tmp-{file_name}",
  350. )
  351. def has_checkpoint(self, dirpath: str):
  352. """Whether a checkpoint file exists within dirpath."""
  353. return any(Path(dirpath).glob(self.CKPT_FILE_TMPL.format("*")))
  354. def restore_from_dir(self, dirpath: str):
  355. """Restores self + searcher + search wrappers from dirpath."""
  356. state_dict = _load_newest_checkpoint(dirpath, self.CKPT_FILE_TMPL.format("*"))
  357. if not state_dict:
  358. raise RuntimeError("Unable to find checkpoint in {}.".format(dirpath))
  359. self.set_state(state_dict)