communicator.py 1.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. from dataclasses import dataclass
  2. from typing import List
  3. import ray
  4. from ray.util.collective.types import Backend
  5. @dataclass
  6. class Communicator:
  7. """
  8. A handle to a communicator that we are a member of.
  9. """
  10. # The name of the communicator.
  11. name: str
  12. # Our rank in the collective group.
  13. rank: int
  14. # A valid backend, as defined by
  15. # ray.util.collective.types.Backend.
  16. backend: str
  17. class CommunicatorHandle:
  18. """
  19. A communicator handle used by the driver to store handles to the
  20. actors in the communicator.
  21. """
  22. def __init__(self, actors: List[ray.actor.ActorHandle], name: str, backend: str):
  23. """
  24. Initializes the CommunicatorHandle with the given actor handles.
  25. Assumes that the communicator has already been initialized on all actors.
  26. Args:
  27. actors: A list of actor handles to be stored.
  28. name: Name of the communicator.
  29. backend: Communicator backend. See
  30. ray.util.collective.types for valid values.
  31. """
  32. self._actors = actors
  33. self._name = name
  34. self._backend = Backend(backend)
  35. def get_rank(self, actor: ray.actor.ActorHandle):
  36. for i, a in enumerate(self._actors):
  37. if a == actor:
  38. return i
  39. return -1
  40. @property
  41. def actors(self) -> List[ray.actor.ActorHandle]:
  42. """
  43. Return all actor handles in this communicator.
  44. """
  45. return self._actors[:]
  46. @property
  47. def name(self) -> str:
  48. return self._name
  49. @property
  50. def backend(self) -> str:
  51. return self._backend