| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485 |
- """Utility classes for collectives."""
- import asyncio
- import logging
- import ray
- logger = logging.getLogger(__name__)
@ray.remote
class NCCLUniqueIDStore:
    """NCCLUniqueID store as a named actor class.

    Args:
        name: the unique name for this named actor.

    Attributes:
        name: the unique name for this named actor.
        nccl_id: the NCCLUniqueID held in this store; ``None`` until
            `set_id` has been called.
        event: asyncio event that is set once an ID has been stored.
    """

    def __init__(self, name):
        self.name = name
        self.nccl_id = None
        # Signalled by `set_id`; awaited by `wait_and_get_id`.
        self.event = asyncio.Event()

    async def set_id(self, uid):
        """Initialize the NCCL unique ID for this store.

        Args:
            uid: the unique ID generated via the NCCL
                generate_communicator_id API.

        Returns:
            The NCCL unique ID set.
        """
        self.nccl_id = uid
        # Release every coroutine currently blocked in `wait_and_get_id`.
        self.event.set()
        return uid

    async def wait_and_get_id(self):
        """Wait for the NCCL unique ID to be set and return it."""
        await self.event.wait()
        return self.nccl_id

    def get_id(self):
        """Get the NCCL unique ID held in this store.

        Does not wait: if no ID has been stored yet, a warning is logged
        and ``None`` is returned.
        """
        # Compare against the `None` sentinel from `__init__` rather than
        # relying on truthiness, so a falsy-but-set ID is not misreported.
        if self.nccl_id is None:
            logger.warning(
                "The NCCL ID has not been set yet for store %s.", self.name
            )
        return self.nccl_id
@ray.remote
class Info:
    """Hold the group information created via `create_collective_group`.

    Note: Should be used as a NamedActor.

    Attributes:
        ids: initially ``None``; overwritten by `set_info`.
        world_size: initially ``-1``; overwritten by `set_info`.
        rank: initially ``-1``; overwritten by `set_info`.
        backend: initially ``None``; overwritten by `set_info`.
        gloo_timeout: initially ``30000`` (unit not stated here --
            TODO confirm against the caller).
    """

    def __init__(self):
        # Placeholder values that stay in effect until `set_info` is called.
        (
            self.ids,
            self.world_size,
            self.rank,
            self.backend,
            self.gloo_timeout,
        ) = (None, -1, -1, None, 30000)

    def set_info(self, ids, world_size, rank, backend, gloo_timeout):
        """Record the collective information, replacing any earlier values."""
        (
            self.ids,
            self.world_size,
            self.rank,
            self.backend,
            self.gloo_timeout,
        ) = (ids, world_size, rank, backend, gloo_timeout)

    def get_info(self):
        """Return the stored collective information as a 5-tuple.

        Returns:
            ``(ids, world_size, rank, backend, gloo_timeout)``.
        """
        stored = (
            self.ids,
            self.world_size,
            self.rank,
            self.backend,
            self.gloo_timeout,
        )
        return stored
|