config.py 1.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
  1. from dataclasses import dataclass
  2. from ray.air.config import (
  3. CheckpointConfig as _CheckpointConfig,
  4. FailureConfig as _FailureConfig,
  5. RunConfig as _RunConfig,
  6. )
  7. from ray.train.constants import (
  8. V2_MIGRATION_GUIDE_MESSAGE,
  9. _v2_migration_warnings_enabled,
  10. )
  11. from ray.train.utils import _copy_doc, _log_deprecation_warning
  # NOTE: These are just pass-through wrappers around the `ray.air.config` classes
  # (`CheckpointConfig`, `FailureConfig`, `RunConfig`) in order to detect whether the
  # import module was correct (e.g. `ray.tune.RunConfig`).
  @dataclass
  @_copy_doc(_CheckpointConfig)
  class CheckpointConfig(_CheckpointConfig):
      # Adds no fields or behavior on top of `ray.air.config.CheckpointConfig`.
      # It exists as a distinct type so that `RunConfig.__post_init__` can use
      # `isinstance` to tell whether a checkpoint config came from this module.
      # The class docs are set by `_copy_doc(_CheckpointConfig)` rather than by
      # a docstring written here.
      ...
  @dataclass
  @_copy_doc(_FailureConfig)
  class FailureConfig(_FailureConfig):
      # Adds no fields or behavior on top of `ray.air.config.FailureConfig`.
      # It exists as a distinct type so that `RunConfig.__post_init__` can use
      # `isinstance` to tell whether a failure config came from this module.
      # The class docs are set by `_copy_doc(_FailureConfig)` rather than by
      # a docstring written here.
      ...
  @dataclass
  @_copy_doc(_RunConfig)
  class RunConfig(_RunConfig):
      # `ray.tune` flavor of `ray.air.config.RunConfig`. The class docs are set
      # by `_copy_doc(_RunConfig)`, so notes specific to this subclass are kept
      # in comments and in the `__post_init__` docstring below.

      def __post_init__(self):
          """Fill in `ray.tune` config defaults and warn about foreign configs.

          A falsy ``checkpoint_config`` / ``failure_config`` is replaced with the
          ``CheckpointConfig`` / ``FailureConfig`` wrappers defined in this
          module *before* the parent ``__post_init__`` runs. Afterwards, each of
          those two fields that is not an instance of the corresponding wrapper
          from this module (e.g. a config imported from elsewhere) produces a
          deprecation warning, but only when ``_v2_migration_warnings_enabled()``
          returns true.
          """
          # Defaults are populated before delegating to the parent hook;
          # presumably this keeps the parent from installing its own
          # (non-`ray.tune`) default configs -- TODO confirm against _RunConfig.
          self.checkpoint_config = self.checkpoint_config or CheckpointConfig()
          self.failure_config = self.failure_config or FailureConfig()

          super().__post_init__()

          # Checked in the same order as before: checkpoint, then failure.
          for config, expected_cls, class_name in (
              (self.checkpoint_config, CheckpointConfig, "CheckpointConfig"),
              (self.failure_config, FailureConfig, "FailureConfig"),
          ):
              # `and` short-circuits, so the warnings-enabled check only runs on
              # a type mismatch, just like the previous nested `if`.
              if (
                  not isinstance(config, expected_cls)
                  and _v2_migration_warnings_enabled()
              ):
                  # Trailing space after "imports." keeps the migration-guide
                  # message from running into the previous sentence.
                  _log_deprecation_warning(
                      f"The `{class_name}` class should be imported from `ray.tune` "
                      "when passing it to the Tuner. Please update your imports. "
                      f"{V2_MIGRATION_GUIDE_MESSAGE}"
                  )