util.py 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. """Utility actor classes for collectives."""
  2. import asyncio
  3. import logging
  4. import ray
  5. logger = logging.getLogger(__name__)
  6. @ray.remote
  7. class NCCLUniqueIDStore:
  8. """NCCLUniqueID Store as a named actor class.
  9. Args:
  10. name: the unique name for this named actor.
  11. Attributes:
  12. name: the unique name for this named actor.
  13. nccl_id: the NCCLUniqueID held in this store.
  14. """
  15. def __init__(self, name):
  16. self.name = name
  17. self.nccl_id = None
  18. self.event = asyncio.Event()
  19. async def set_id(self, uid):
  20. """
  21. Initialize the NCCL unique ID for this store.
  22. Args:
  23. uid: the unique ID generated via the NCCL generate_communicator_id API.
  24. Returns:
  25. The NCCL unique ID set.
  26. """
  27. self.nccl_id = uid
  28. self.event.set()
  29. return uid
  30. async def wait_and_get_id(self):
  31. """Wait for the NCCL unique ID to be set and return it."""
  32. await self.event.wait()
  33. return self.nccl_id
  34. def get_id(self):
  35. """Get the NCCL unique ID held in this store."""
  36. if not self.nccl_id:
  37. logger.warning(
  38. "The NCCL ID has not been set yet for store {}.".format(self.name)
  39. )
  40. return self.nccl_id
  41. @ray.remote
  42. class Info:
  43. """Store the group information created via `create_collective_group`.
  44. Note: Should be used as a NamedActor.
  45. """
  46. def __init__(self):
  47. self.ids = None
  48. self.world_size = -1
  49. self.rank = -1
  50. self.backend = None
  51. self.gloo_timeout = 30000
  52. def set_info(self, ids, world_size, rank, backend, gloo_timeout):
  53. """Store collective information."""
  54. self.ids = ids
  55. self.world_size = world_size
  56. self.rank = rank
  57. self.backend = backend
  58. self.gloo_timeout = gloo_timeout
  59. def get_info(self):
  60. """Get previously stored collective information."""
  61. return (
  62. self.ids,
  63. self.world_size,
  64. self.rank,
  65. self.backend,
  66. self.gloo_timeout,
  67. )