tune_config.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. import datetime
  2. from dataclasses import dataclass
  3. from enum import Enum
  4. from typing import Callable, Optional, Union
  5. from ray.train.constants import _DEPRECATED_VALUE
  6. from ray.tune.experiment.trial import Trial
  7. from ray.tune.schedulers import TrialScheduler
  8. from ray.tune.search import SearchAlgorithm, Searcher
  9. from ray.util.annotations import DeveloperAPI, PublicAPI
  10. @dataclass
  11. @PublicAPI(stability="beta")
  12. class TuneConfig:
  13. """Tune specific configs.
  14. Args:
  15. metric: Metric to optimize. This metric should be reported
  16. with `tune.report()`. If set, will be passed to the search
  17. algorithm and scheduler.
  18. mode: Must be one of [min, max]. Determines whether objective is
  19. minimizing or maximizing the metric attribute. If set, will be
  20. passed to the search algorithm and scheduler.
  21. search_alg: Search algorithm for optimization. Default to
  22. random search.
  23. scheduler: Scheduler for executing the experiment.
  24. Choose among FIFO (default), MedianStopping,
  25. AsyncHyperBand, HyperBand and PopulationBasedTraining. Refer to
  26. ray.tune.schedulers for more options.
  27. num_samples: Number of times to sample from the
  28. hyperparameter space. Defaults to 1. If `grid_search` is
  29. provided as an argument, the grid will be repeated
  30. `num_samples` of times. If this is -1, (virtually) infinite
  31. samples are generated until a stopping condition is met.
  32. max_concurrent_trials: Maximum number of trials to run
  33. concurrently. Must be non-negative. If None or 0, no limit will
  34. be applied. This is achieved by wrapping the ``search_alg`` in
  35. a :class:`ConcurrencyLimiter`, and thus setting this argument
  36. will raise an exception if the ``search_alg`` is already a
  37. :class:`ConcurrencyLimiter`. Defaults to None.
  38. time_budget_s: Global time budget in
  39. seconds after which all trials are stopped. Can also be a
  40. ``datetime.timedelta`` object.
  41. reuse_actors: Whether to reuse actors between different trials
  42. when possible. This can drastically speed up experiments that start
  43. and stop actors often (e.g., PBT in time-multiplexing mode). This
  44. requires trials to have the same resource requirements.
  45. Defaults to ``False``.
  46. trial_name_creator: Optional function that takes in a Trial and returns
  47. its name (i.e. its string representation). Be sure to include some unique
  48. identifier (such as `Trial.trial_id`) in each trial's name.
  49. NOTE: This API is in alpha and subject to change.
  50. trial_dirname_creator: Optional function that takes in a trial and
  51. generates its trial directory name as a string. Be sure to include some
  52. unique identifier (such as `Trial.trial_id`) is used in each trial's
  53. directory name. Otherwise, trials could overwrite artifacts and checkpoints
  54. of other trials. The return value cannot be a path.
  55. NOTE: This API is in alpha and subject to change.
  56. chdir_to_trial_dir: Deprecated. Set the `RAY_CHDIR_TO_TRIAL_DIR` env var instead
  57. """
  58. # Currently this is not at feature parity with `tune.run`, nor should it be.
  59. # The goal is to reach a fine balance between API flexibility and conciseness.
  60. # We should carefully introduce arguments here instead of just dumping everything.
  61. mode: Optional[str] = None
  62. metric: Optional[str] = None
  63. search_alg: Optional[Union[Searcher, SearchAlgorithm]] = None
  64. scheduler: Optional[TrialScheduler] = None
  65. num_samples: int = 1
  66. max_concurrent_trials: Optional[int] = None
  67. time_budget_s: Optional[Union[int, float, datetime.timedelta]] = None
  68. reuse_actors: bool = False
  69. trial_name_creator: Optional[Callable[[Trial], str]] = None
  70. trial_dirname_creator: Optional[Callable[[Trial], str]] = None
  71. chdir_to_trial_dir: bool = _DEPRECATED_VALUE
  72. @DeveloperAPI
  73. @dataclass
  74. class ResumeConfig:
  75. """[Experimental] This config is used to specify how to resume Tune trials."""
  76. class ResumeType(Enum):
  77. """An enumeration to define resume types for various trial states.
  78. Members:
  79. RESUME: Resume from the latest checkpoint.
  80. RESTART: Restart from the beginning (with no checkpoint).
  81. SKIP: Skip this trial when resuming by treating it as terminated.
  82. """
  83. RESUME = "resume"
  84. RESTART = "restart"
  85. SKIP = "skip"
  86. finished: str = ResumeType.SKIP
  87. unfinished: str = ResumeType.RESUME
  88. errored: str = ResumeType.SKIP