| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135 |
# Try import ray[train] core requirements (defined in setup.py)
# isort: off
# Probe the third-party packages that ray[train] lists as core requirements
# before anything else is imported, so that a partial install fails here with
# a single actionable message. The plain ``import`` statements are kept (rather
# than e.g. importlib) because they bind each module as a package attribute.
try:
    import fsspec  # noqa: F401
    import pandas  # noqa: F401
    import pyarrow  # noqa: F401
    import requests  # noqa: F401
except ImportError as missing_dependency:
    # Chain the original error so the traceback still names the missing module.
    raise ImportError(
        "Can't import ray.train as some dependencies are missing. "
        'Run `pip install "ray[train]"` to fix.'
    ) from missing_dependency
# isort: on
- from ray.air.config import CheckpointConfig, FailureConfig, RunConfig, ScalingConfig
- from ray.air.result import Result
- # Import Checkpoint ahead of the other ray.train imports below so it can be used in other modules
- from ray.train._checkpoint import Checkpoint
- from ray.train._internal.data_config import DataConfig
- from ray.train._internal.session import get_checkpoint, get_dataset_shard, report
- from ray.train._internal.syncer import SyncConfig
- from ray.train.backend import BackendConfig
- from ray.train.base_trainer import TrainingFailedError
- from ray.train.constants import TRAIN_DATASET_KEY
- from ray.train.context import TrainContext, get_context
- from ray.train.v2._internal.constants import is_v2_enabled
if is_v2_enabled():
    # `ray.train.v2` needs pydantic. Check for it up front so a missing install
    # surfaces as an actionable error message with the original cause chained.
    try:
        import pydantic  # noqa: F401
    except ImportError as exc:
        # ModuleNotFoundError is a subclass of ImportError, so catching
        # ImportError alone covers both cases.
        raise ImportError(
            "`ray.train.v2` requires the pydantic package, which is missing. "
            "Run the following command to fix this: `pip install pydantic`"
        ) from exc

    # The imports below intentionally rebind public names that were imported
    # from the v1 modules above (hence ``noqa: F811``), so that when v2 is
    # enabled ``ray.train.<name>`` resolves to the v2 implementation.
    from ray.train.v2.api.callback import UserCallback  # noqa: F811
    from ray.train.v2.api.config import (  # noqa: F811
        CheckpointConfig,
        FailureConfig,
        RunConfig,
        ScalingConfig,
    )
    from ray.train.v2.api.context import TrainContext  # noqa: F811
    from ray.train.v2.api.exceptions import (  # noqa: F811
        ControllerError,
        TrainingFailedError,
        WorkerGroupError,
    )
    from ray.train.v2.api.report_config import (  # noqa: F811
        CheckpointConsistencyMode,
        CheckpointUploadMode,
    )
    from ray.train.v2.api.reported_checkpoint import ReportedCheckpoint  # noqa: F811
    from ray.train.v2.api.result import Result  # noqa: F811
    from ray.train.v2.api.train_fn_utils import (  # noqa: F811
        get_all_reported_checkpoints,
        get_checkpoint,
        get_context,
        get_dataset_shard,
        report,
    )
    from ray.train.v2.api.validation_config import (  # noqa: F811
        ValidationConfig,
        ValidationFn,
        ValidationTaskConfig,
    )
# Names exported by ``ray.train`` whichever Train version is enabled. In v2
# mode several of these names are bound to the v2 objects imported above.
__all__ = [
    "get_checkpoint",
    "get_context",
    "get_dataset_shard",
    "report",
    "BackendConfig",
    "Checkpoint",
    "CheckpointConfig",
    "DataConfig",
    "FailureConfig",
    "Result",
    "RunConfig",
    "ScalingConfig",
    "SyncConfig",
    "TrainContext",
    "TrainingFailedError",
    "TRAIN_DATASET_KEY",
]

# Report ``ray.train`` as the defining module of every exported object except
# ``TRAIN_DATASET_KEY`` (which is not touched). Class reprs, for example, are
# built from ``__module__``, so these present as ``ray.train.<name>`` instead of
# the private submodule they were defined in.
for _exported in (
    get_checkpoint,
    get_context,
    get_dataset_shard,
    report,
    BackendConfig,
    Checkpoint,
    CheckpointConfig,
    DataConfig,
    FailureConfig,
    Result,
    RunConfig,
    ScalingConfig,
    SyncConfig,
    TrainContext,
    TrainingFailedError,
):
    _exported.__module__ = "ray.train"
del _exported  # keep the loop variable out of the package namespace
# TODO: consider implementing these in v1 and raising ImportError instead.
if is_v2_enabled():
    # Public names that only exist in the Train v2 API. The list is extended in
    # place, so ``__all__`` remains the same list object.
    __all__ += [
        "CheckpointUploadMode",
        "CheckpointConsistencyMode",
        "ControllerError",
        "ReportedCheckpoint",
        "UserCallback",
        "WorkerGroupError",
        "ValidationConfig",
        "ValidationFn",
        "ValidationTaskConfig",
        "get_all_reported_checkpoints",
    ]

    # Present the v2-only objects as members of ``ray.train`` too.
    for _v2_exported in (
        CheckpointUploadMode,
        CheckpointConsistencyMode,
        ControllerError,
        ReportedCheckpoint,
        UserCallback,
        WorkerGroupError,
        ValidationConfig,
        ValidationFn,
        ValidationTaskConfig,
        get_all_reported_checkpoints,
    ):
        _v2_exported.__module__ = "ray.train"
    del _v2_exported  # keep the loop variable out of the package namespace

# DO NOT ADD ANYTHING AFTER THIS LINE.
|