hyperparameter_search.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. # Copyright 2023-present the HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from .integrations import (
  15. is_optuna_available,
  16. is_ray_tune_available,
  17. is_wandb_available,
  18. run_hp_search_optuna,
  19. run_hp_search_ray,
  20. run_hp_search_wandb,
  21. )
  22. from .trainer_utils import (
  23. HPSearchBackend,
  24. default_hp_space_optuna,
  25. default_hp_space_ray,
  26. default_hp_space_wandb,
  27. )
  28. from .utils import logging
# Module-level logger obtained from the package's own ``utils.logging`` helper
# (imported above as ``logging``; this is not the stdlib module).
logger = logging.get_logger(__name__)
  30. class HyperParamSearchBackendBase:
  31. name: str
  32. pip_package: str | None = None
  33. @staticmethod
  34. def is_available():
  35. raise NotImplementedError
  36. def run(self, trainer, n_trials: int, direction: str, **kwargs):
  37. raise NotImplementedError
  38. def default_hp_space(self, trial):
  39. raise NotImplementedError
  40. def ensure_available(self):
  41. if not self.is_available():
  42. raise RuntimeError(
  43. f"You picked the {self.name} backend, but it is not installed. Run {self.pip_install()}."
  44. )
  45. @classmethod
  46. def pip_install(cls):
  47. return f"`pip install {cls.pip_package or cls.name}`"
  48. class OptunaBackend(HyperParamSearchBackendBase):
  49. name = "optuna"
  50. @staticmethod
  51. def is_available():
  52. return is_optuna_available()
  53. def run(self, trainer, n_trials: int, direction: str, **kwargs):
  54. return run_hp_search_optuna(trainer, n_trials, direction, **kwargs)
  55. def default_hp_space(self, trial):
  56. return default_hp_space_optuna(trial)
  57. class RayTuneBackend(HyperParamSearchBackendBase):
  58. name = "ray"
  59. pip_package = "'ray[tune]'"
  60. @staticmethod
  61. def is_available():
  62. return is_ray_tune_available()
  63. def run(self, trainer, n_trials: int, direction: str, **kwargs):
  64. return run_hp_search_ray(trainer, n_trials, direction, **kwargs)
  65. def default_hp_space(self, trial):
  66. return default_hp_space_ray(trial)
  67. class WandbBackend(HyperParamSearchBackendBase):
  68. name = "wandb"
  69. @staticmethod
  70. def is_available():
  71. return is_wandb_available()
  72. def run(self, trainer, n_trials: int, direction: str, **kwargs):
  73. return run_hp_search_wandb(trainer, n_trials, direction, **kwargs)
  74. def default_hp_space(self, trial):
  75. return default_hp_space_wandb(trial)
  76. ALL_HYPERPARAMETER_SEARCH_BACKENDS = {
  77. HPSearchBackend(backend.name): backend for backend in [OptunaBackend, RayTuneBackend, WandbBackend]
  78. }
  79. def default_hp_search_backend() -> str:
  80. available_backends = [backend for backend in ALL_HYPERPARAMETER_SEARCH_BACKENDS.values() if backend.is_available()]
  81. if len(available_backends) > 0:
  82. name = available_backends[0].name
  83. if len(available_backends) > 1:
  84. logger.info(
  85. f"{len(available_backends)} hyperparameter search backends available. Using {name} as the default."
  86. )
  87. return name
  88. raise RuntimeError(
  89. "No hyperparameter search backend available.\n"
  90. + "\n".join(
  91. f" - To install {backend.name} run {backend.pip_install()}"
  92. for backend in ALL_HYPERPARAMETER_SEARCH_BACKENDS.values()
  93. )
  94. )