backend.py 1.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. import logging
  2. from contextlib import nullcontext
  3. from typing import TypeVar
  4. from ray.train._internal.base_worker_group import BaseWorkerGroup
  5. from ray.train._internal.utils import Singleton
  6. from ray.util.annotations import DeveloperAPI
  7. from ray.widgets import make_table_html_repr
# Generic type variable; nothing in the visible part of this module uses it.
EncodedData = TypeVar("EncodedData")
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
  10. @DeveloperAPI
  11. class BackendConfig:
  12. """Parent class for configurations of training backend."""
  13. @property
  14. def backend_cls(self):
  15. return Backend
  16. @property
  17. def train_func_context(self):
  18. return nullcontext
  19. def _repr_html_(self) -> str:
  20. return make_table_html_repr(obj=self, title=type(self).__name__)
  21. @DeveloperAPI
  22. class Backend(metaclass=Singleton):
  23. """Singleton for distributed communication backend.
  24. Attributes:
  25. share_cuda_visible_devices: If True, each worker
  26. process will have CUDA_VISIBLE_DEVICES set as the visible device
  27. IDs of all workers on the same node for this training instance.
  28. If False, each worker will have CUDA_VISIBLE_DEVICES set to the
  29. device IDs allocated by Ray for that worker.
  30. """
  31. share_cuda_visible_devices: bool = False
  32. def on_start(self, worker_group: BaseWorkerGroup, backend_config: BackendConfig):
  33. """Logic for starting this backend."""
  34. pass
  35. def on_shutdown(self, worker_group: BaseWorkerGroup, backend_config: BackendConfig):
  36. """Logic for shutting down the backend."""
  37. pass
  38. def on_training_start(
  39. self, worker_group: BaseWorkerGroup, backend_config: BackendConfig
  40. ):
  41. """Logic ran right before training is started.
  42. Session API is available at this point."""
  43. pass