__init__.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. # Try import ray[train] core requirements (defined in setup.py)
  2. # isort: off
  3. try:
  4. import fsspec # noqa: F401
  5. import pandas # noqa: F401
  6. import pyarrow # noqa: F401
  7. import requests # noqa: F401
  8. except ImportError as exc:
  9. raise ImportError(
  10. "Can't import ray.train as some dependencies are missing. "
  11. 'Run `pip install "ray[train]"` to fix.'
  12. ) from exc
  13. # isort: on
  14. from ray.air.config import CheckpointConfig, FailureConfig, RunConfig, ScalingConfig
  15. from ray.air.result import Result
  16. # Import this first so it can be used in other modules
  17. from ray.train._checkpoint import Checkpoint
  18. from ray.train._internal.data_config import DataConfig
  19. from ray.train._internal.session import get_checkpoint, get_dataset_shard, report
  20. from ray.train._internal.syncer import SyncConfig
  21. from ray.train.backend import BackendConfig
  22. from ray.train.base_trainer import TrainingFailedError
  23. from ray.train.constants import TRAIN_DATASET_KEY
  24. from ray.train.context import TrainContext, get_context
  25. from ray.train.v2._internal.constants import is_v2_enabled
  26. if is_v2_enabled():
  27. try:
  28. import pydantic # noqa: F401
  29. except (ImportError, ModuleNotFoundError) as exc:
  30. raise ImportError(
  31. "`ray.train.v2` requires the pydantic package, which is missing. "
  32. "Run the following command to fix this: `pip install pydantic`"
  33. ) from exc
  34. from ray.train.v2.api.callback import UserCallback # noqa: F811
  35. from ray.train.v2.api.config import ( # noqa: F811
  36. CheckpointConfig,
  37. FailureConfig,
  38. RunConfig,
  39. ScalingConfig,
  40. )
  41. from ray.train.v2.api.context import TrainContext # noqa: F811
  42. from ray.train.v2.api.exceptions import ( # noqa: F811
  43. ControllerError,
  44. TrainingFailedError,
  45. WorkerGroupError,
  46. )
  47. from ray.train.v2.api.report_config import ( # noqa: F811
  48. CheckpointConsistencyMode,
  49. CheckpointUploadMode,
  50. )
  51. from ray.train.v2.api.reported_checkpoint import ReportedCheckpoint # noqa: F811
  52. from ray.train.v2.api.result import Result # noqa: F811
  53. from ray.train.v2.api.train_fn_utils import ( # noqa: F811
  54. get_all_reported_checkpoints,
  55. get_checkpoint,
  56. get_context,
  57. get_dataset_shard,
  58. report,
  59. )
  60. from ray.train.v2.api.validation_config import ( # noqa: F811
  61. ValidationConfig,
  62. ValidationFn,
  63. ValidationTaskConfig,
  64. )
  65. __all__ = [
  66. "get_checkpoint",
  67. "get_context",
  68. "get_dataset_shard",
  69. "report",
  70. "BackendConfig",
  71. "Checkpoint",
  72. "CheckpointConfig",
  73. "DataConfig",
  74. "FailureConfig",
  75. "Result",
  76. "RunConfig",
  77. "ScalingConfig",
  78. "SyncConfig",
  79. "TrainContext",
  80. "TrainingFailedError",
  81. "TRAIN_DATASET_KEY",
  82. ]
  83. get_checkpoint.__module__ = "ray.train"
  84. get_context.__module__ = "ray.train"
  85. get_dataset_shard.__module__ = "ray.train"
  86. report.__module__ = "ray.train"
  87. BackendConfig.__module__ = "ray.train"
  88. Checkpoint.__module__ = "ray.train"
  89. CheckpointConfig.__module__ = "ray.train"
  90. DataConfig.__module__ = "ray.train"
  91. FailureConfig.__module__ = "ray.train"
  92. Result.__module__ = "ray.train"
  93. RunConfig.__module__ = "ray.train"
  94. ScalingConfig.__module__ = "ray.train"
  95. SyncConfig.__module__ = "ray.train"
  96. TrainContext.__module__ = "ray.train"
  97. TrainingFailedError.__module__ = "ray.train"
  98. # TODO: consider implementing these in v1 and raising ImportError instead.
  99. if is_v2_enabled():
  100. __all__.extend(
  101. [
  102. "CheckpointUploadMode",
  103. "CheckpointConsistencyMode",
  104. "ControllerError",
  105. "ReportedCheckpoint",
  106. "UserCallback",
  107. "WorkerGroupError",
  108. "ValidationConfig",
  109. "ValidationFn",
  110. "ValidationTaskConfig",
  111. "get_all_reported_checkpoints",
  112. ]
  113. )
  114. CheckpointUploadMode.__module__ = "ray.train"
  115. CheckpointConsistencyMode.__module__ = "ray.train"
  116. ControllerError.__module__ = "ray.train"
  117. ReportedCheckpoint.__module__ = "ray.train"
  118. UserCallback.__module__ = "ray.train"
  119. WorkerGroupError.__module__ = "ray.train"
  120. ValidationConfig.__module__ = "ray.train"
  121. ValidationFn.__module__ = "ray.train"
  122. ValidationTaskConfig.__module__ = "ray.train"
  123. get_all_reported_checkpoints.__module__ = "ray.train"
  124. # DO NOT ADD ANYTHING AFTER THIS LINE.