| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253 |
- from dataclasses import dataclass
- from ray.air.config import (
- CheckpointConfig as _CheckpointConfig,
- FailureConfig as _FailureConfig,
- RunConfig as _RunConfig,
- )
- from ray.train.constants import (
- V2_MIGRATION_GUIDE_MESSAGE,
- _v2_migration_warnings_enabled,
- )
- from ray.train.utils import _copy_doc, _log_deprecation_warning
- # NOTE: These classes are just pass-through wrappers around the `ray.air.config`
- # classes, used to detect whether a config object was imported from the correct
- # module (e.g. `ray.tune.RunConfig`).
# Pass-through subclass of `ray.air.config.CheckpointConfig`: it adds no fields
# and overrides nothing. It exists so that `RunConfig.__post_init__` in this
# module can use `isinstance` to tell whether a checkpoint config was created
# from this module's class or from some other class.
# `_copy_doc` is applied first (presumably to carry over the parent's docstring
# -- confirm against `ray.train.utils`), then `dataclass`.
@dataclass
@_copy_doc(_CheckpointConfig)
class CheckpointConfig(_CheckpointConfig):
    pass
# Pass-through subclass of `ray.air.config.FailureConfig`: it adds no fields
# and overrides nothing. It exists so that `RunConfig.__post_init__` in this
# module can use `isinstance` to tell whether a failure config was created
# from this module's class or from some other class.
# `_copy_doc` is applied first (presumably to carry over the parent's docstring
# -- confirm against `ray.train.utils`), then `dataclass`.
@dataclass
@_copy_doc(_FailureConfig)
class FailureConfig(_FailureConfig):
    pass
# Subclass of `ray.air.config.RunConfig` that defaults its sub-configs to this
# module's `CheckpointConfig` / `FailureConfig` and warns when a caller passes
# a sub-config that is not an instance of those classes.
@dataclass
@_copy_doc(_RunConfig)
class RunConfig(_RunConfig):
    def __post_init__(self):
        """Fill in default sub-configs, run the parent hook, then check types.

        Unset (falsy) `checkpoint_config` / `failure_config` values are
        replaced with instances of this module's classes *before* delegating
        to the parent `__post_init__`. Afterwards, each sub-config that is not
        an instance of this module's class triggers a deprecation warning,
        provided V2 migration warnings are enabled.
        """
        self.checkpoint_config = self.checkpoint_config or CheckpointConfig()
        self.failure_config = self.failure_config or FailureConfig()

        super().__post_init__()

        # (current value, class it is expected to be an instance of)
        expected_types = (
            (self.checkpoint_config, CheckpointConfig),
            (self.failure_config, FailureConfig),
        )
        for config, expected_cls in expected_types:
            # Same evaluation order as before: the type check comes first and
            # the warnings-enabled switch is only consulted on a mismatch.
            if isinstance(config, expected_cls):
                continue
            if not _v2_migration_warnings_enabled():
                continue
            # The space after "imports." keeps the migration-guide text from
            # running straight into the preceding sentence.
            _log_deprecation_warning(
                f"The `{expected_cls.__name__}` class should be imported from "
                "`ray.tune` when passing it to the Tuner. "
                f"Please update your imports. {V2_MIGRATION_GUIDE_MESSAGE}"
            )
|