collectives.py 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. import logging
  2. from typing import Optional, TypeVar
  3. from ray.train.v2._internal.execution.train_fn_utils import get_train_fn_utils
  4. from ray.train.v2._internal.util import requires_train_worker
  5. from ray.util.annotations import PublicAPI
  6. T = TypeVar("T", bound=Optional[object])
  7. logger = logging.getLogger(__file__)
  8. @PublicAPI(stability="alpha")
  9. @requires_train_worker()
  10. def broadcast_from_rank_zero(data: T) -> T:
  11. """Broadcast small (<1kb) data from the rank 0 worker to all other workers.
  12. Serves as a barrier, meaning that all workers must call this method before
  13. the training function can continue.
  14. Example:
  15. .. testcode:
  16. from ray.train import get_context
  17. from ray.train.collective import broadcast_from_rank_zero
  18. from ray.train.torch import TorchTrainer
  19. def train_func():
  20. ...
  21. if get_context().get_world_rank() == 0:
  22. data = {"some_key": "some_value"}
  23. else:
  24. data = None
  25. data = broadcast_from_rank_zero(data)
  26. ...
  27. trainer = TorchTrainer(train_func)
  28. trainer.fit()
  29. Args:
  30. data: The small (1kb) data to broadcast from the rank 0 worker to all
  31. other workers.
  32. Returns:
  33. The data broadcasted from the rank 0 worker.
  34. Raises:
  35. ValueError: If the data is too big.
  36. pickle.PicklingError: If the data is not pickleable.
  37. TypeError: If the data is not pickleable.
  38. """
  39. return get_train_fn_utils().broadcast_from_rank_zero(data)
  40. @PublicAPI(stability="alpha")
  41. @requires_train_worker()
  42. def barrier() -> None:
  43. """Create a barrier across all workers.
  44. All workers must call this method before the training function can continue.
  45. Example:
  46. .. testcode:
  47. from ray.train import get_context
  48. from ray.train.collective import barrier
  49. from ray.train.torch import TorchTrainer
  50. def train_func():
  51. ...
  52. print(f"Rank {get_context().get_world_rank()} is waiting at the barrier.")
  53. barrier()
  54. print(f"Rank {get_context().get_world_rank()} has passed the barrier.")
  55. ...
  56. trainer = TorchTrainer(train_func)
  57. trainer.fit()
  58. """
  59. return get_train_fn_utils().barrier()