__init__.py 2.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. import inspect
  2. from ray._common.utils import get_function_args
  3. from ray.tune.schedulers.async_hyperband import ASHAScheduler, AsyncHyperBandScheduler
  4. from ray.tune.schedulers.hb_bohb import HyperBandForBOHB
  5. from ray.tune.schedulers.hyperband import HyperBandScheduler
  6. from ray.tune.schedulers.median_stopping_rule import MedianStoppingRule
  7. from ray.tune.schedulers.pbt import (
  8. PopulationBasedTraining,
  9. PopulationBasedTrainingReplay,
  10. )
  11. from ray.tune.schedulers.resource_changing_scheduler import ResourceChangingScheduler
  12. from ray.tune.schedulers.trial_scheduler import FIFOScheduler, TrialScheduler
  13. from ray.util import PublicAPI
  14. def _pb2_importer():
  15. # PB2 is imported lazily since it has additional dependencies.
  16. from ray.tune.schedulers.pb2 import PB2
  17. return PB2
  18. # Values in this dictionary will be one two kinds:
  19. # class of the scheduler object to create
  20. # wrapper function to support a lazy import of the scheduler class
  21. SCHEDULER_IMPORT = {
  22. "fifo": FIFOScheduler,
  23. "async_hyperband": AsyncHyperBandScheduler,
  24. "asynchyperband": AsyncHyperBandScheduler,
  25. "median_stopping_rule": MedianStoppingRule,
  26. "medianstopping": MedianStoppingRule,
  27. "hyperband": HyperBandScheduler,
  28. "hb_bohb": HyperBandForBOHB,
  29. "pbt": PopulationBasedTraining,
  30. "pbt_replay": PopulationBasedTrainingReplay,
  31. "pb2": _pb2_importer,
  32. "resource_changing": ResourceChangingScheduler,
  33. }
  34. @PublicAPI(stability="beta")
  35. def create_scheduler(
  36. scheduler,
  37. **kwargs,
  38. ):
  39. """Instantiate a scheduler based on the given string.
  40. This is useful for swapping between different schedulers.
  41. Args:
  42. scheduler: The scheduler to use.
  43. **kwargs: Scheduler parameters.
  44. These keyword arguments will be passed to the initialization
  45. function of the chosen scheduler.
  46. Returns:
  47. ray.tune.schedulers.trial_scheduler.TrialScheduler: The scheduler.
  48. Example:
  49. >>> from ray import tune
  50. >>> pbt_kwargs = {}
  51. >>> scheduler = tune.create_scheduler('pbt', **pbt_kwargs) # doctest: +SKIP
  52. """
  53. scheduler = scheduler.lower()
  54. if scheduler not in SCHEDULER_IMPORT:
  55. raise ValueError(
  56. f"The `scheduler` argument must be one of "
  57. f"{list(SCHEDULER_IMPORT)}. "
  58. f"Got: {scheduler}"
  59. )
  60. SchedulerClass = SCHEDULER_IMPORT[scheduler]
  61. if inspect.isfunction(SchedulerClass):
  62. # invoke the wrapper function to retrieve class
  63. SchedulerClass = SchedulerClass()
  64. scheduler_args = get_function_args(SchedulerClass)
  65. trimmed_kwargs = {k: v for k, v in kwargs.items() if k in scheduler_args}
  66. return SchedulerClass(**trimmed_kwargs)
  67. __all__ = [
  68. "TrialScheduler",
  69. "HyperBandScheduler",
  70. "AsyncHyperBandScheduler",
  71. "ASHAScheduler",
  72. "MedianStoppingRule",
  73. "FIFOScheduler",
  74. "PopulationBasedTraining",
  75. "PopulationBasedTrainingReplay",
  76. "HyperBandForBOHB",
  77. "ResourceChangingScheduler",
  78. ]