actor_group.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230
  1. import logging
  2. import weakref
  3. from dataclasses import dataclass
  4. from typing import Dict, List, Optional, Tuple, Type, TypeVar
  5. import ray
  6. from ray._private.utils import get_ray_doc_version
  7. from ray.actor import ActorHandle
  8. from ray.util.annotations import Deprecated
  9. T = TypeVar("T")
  10. ActorMetadata = TypeVar("ActorMetadata")
  11. logger = logging.getLogger(__name__)
  12. @dataclass
  13. class ActorWrapper:
  14. """Class containing an actor and its metadata."""
  15. actor: ActorHandle
  16. metadata: ActorMetadata
  17. @dataclass
  18. class ActorConfig:
  19. num_cpus: float
  20. num_gpus: float
  21. resources: Optional[Dict[str, float]]
  22. init_args: Tuple
  23. init_kwargs: Dict
  24. class ActorGroupMethod:
  25. def __init__(self, actor_group: "ActorGroup", method_name: str):
  26. self.actor_group = weakref.ref(actor_group)
  27. self._method_name = method_name
  28. def __call__(self, *args, **kwargs):
  29. raise TypeError(
  30. "ActorGroup methods cannot be called directly. "
  31. "Instead "
  32. f"of running 'object.{self._method_name}()', try "
  33. f"'object.{self._method_name}.remote()'."
  34. )
  35. def remote(self, *args, **kwargs):
  36. return [
  37. getattr(a.actor, self._method_name).remote(*args, **kwargs)
  38. for a in self.actor_group().actors
  39. ]
  40. @Deprecated(
  41. message="For stateless/task processing, use ray.util.multiprocessing, see details "
  42. f"in https://docs.ray.io/en/{get_ray_doc_version()}/ray-more-libs/multiprocessing.html. " # noqa: E501
  43. "For stateful/actor processing such as batch prediction, use "
  44. "Datasets.map_batches(compute=ActorPoolStrategy, ...), see details in "
  45. f"https://docs.ray.io/en/{get_ray_doc_version()}/data/api/dataset.html#ray.data.Dataset.map_batches.", # noqa: E501
  46. warning=True,
  47. )
  48. class ActorGroup:
  49. """Group of Ray Actors that can execute arbitrary functions.
  50. ``ActorGroup`` launches Ray actors according to the given
  51. specification. It can then execute arbitrary Python functions in each of
  52. these actors.
  53. If not enough resources are available to launch the actors, the Ray
  54. cluster will automatically scale up if autoscaling is enabled.
  55. Args:
  56. actor_cls: The class to use as the remote actors.
  57. num_actors: The number of the provided Ray actors to
  58. launch. Defaults to 1.
  59. num_cpus_per_actor: The number of CPUs to reserve for each
  60. actor. Fractional values are allowed. Defaults to 1.
  61. num_gpus_per_actor: The number of GPUs to reserve for each
  62. actor. Fractional values are allowed. Defaults to 0.
  63. resources_per_actor (Optional[Dict[str, float]]):
  64. Dictionary specifying the resources that will be
  65. requested for each actor in addition to ``num_cpus_per_actor``
  66. and ``num_gpus_per_actor``.
  67. init_args, init_kwargs: If ``actor_cls`` is provided,
  68. these args will be used for the actor initialization.
  69. """
  70. def __init__(
  71. self,
  72. actor_cls: Type,
  73. num_actors: int = 1,
  74. num_cpus_per_actor: float = 1,
  75. num_gpus_per_actor: float = 0,
  76. resources_per_actor: Optional[Dict[str, float]] = None,
  77. init_args: Optional[Tuple] = None,
  78. init_kwargs: Optional[Dict] = None,
  79. ):
  80. from ray._common.usage.usage_lib import record_library_usage
  81. record_library_usage("util.ActorGroup")
  82. if num_actors <= 0:
  83. raise ValueError(
  84. "The provided `num_actors` must be greater "
  85. f"than 0. Received num_actors={num_actors} "
  86. f"instead."
  87. )
  88. if num_cpus_per_actor < 0 or num_gpus_per_actor < 0:
  89. raise ValueError(
  90. "The number of CPUs and GPUs per actor must "
  91. "not be negative. Received "
  92. f"num_cpus_per_actor={num_cpus_per_actor} and "
  93. f"num_gpus_per_actor={num_gpus_per_actor}."
  94. )
  95. self.actors = []
  96. self.num_actors = num_actors
  97. self.actor_config = ActorConfig(
  98. num_cpus=num_cpus_per_actor,
  99. num_gpus=num_gpus_per_actor,
  100. resources=resources_per_actor,
  101. init_args=init_args or (),
  102. init_kwargs=init_kwargs or {},
  103. )
  104. self._remote_cls = ray.remote(
  105. num_cpus=self.actor_config.num_cpus,
  106. num_gpus=self.actor_config.num_gpus,
  107. resources=self.actor_config.resources,
  108. )(actor_cls)
  109. self.start()
  110. def __getattr__(self, item):
  111. if len(self.actors) == 0:
  112. raise RuntimeError(
  113. "This ActorGroup has been shutdown. Please start it again."
  114. )
  115. # Same implementation as actor.py
  116. return ActorGroupMethod(self, item)
  117. def __len__(self):
  118. return len(self.actors)
  119. def __getitem__(self, item):
  120. return self.actors[item]
  121. def start(self):
  122. """Starts all the actors in this actor group."""
  123. if self.actors and len(self.actors) > 0:
  124. raise RuntimeError(
  125. "The actors have already been started. "
  126. "Please call `shutdown` first if you want to "
  127. "restart them."
  128. )
  129. logger.debug(f"Starting {self.num_actors} actors.")
  130. self.add_actors(self.num_actors)
  131. logger.debug(f"{len(self.actors)} actors have successfully started.")
  132. def shutdown(self, patience_s: float = 5):
  133. """Shutdown all the actors in this actor group.
  134. Args:
  135. patience_s: Attempt a graceful shutdown
  136. of the actors for this many seconds. Fallback to force kill
  137. if graceful shutdown is not complete after this time. If
  138. this is less than or equal to 0, immediately force kill all
  139. actors.
  140. """
  141. logger.debug(f"Shutting down {len(self.actors)} actors.")
  142. if patience_s <= 0:
  143. for actor in self.actors:
  144. ray.kill(actor.actor)
  145. else:
  146. done_refs = [w.actor.__ray_terminate__.remote() for w in self.actors]
  147. # Wait for actors to die gracefully.
  148. done, not_done = ray.wait(done_refs, timeout=patience_s)
  149. if not_done:
  150. logger.debug("Graceful termination failed. Falling back to force kill.")
  151. # If all actors are not able to die gracefully, then kill them.
  152. for actor in self.actors:
  153. ray.kill(actor.actor)
  154. logger.debug("Shutdown successful.")
  155. self.actors = []
  156. def remove_actors(self, actor_indexes: List[int]):
  157. """Removes the actors with the specified indexes.
  158. Args:
  159. actor_indexes (List[int]): The indexes of the actors to remove.
  160. """
  161. new_actors = []
  162. for i in range(len(self.actors)):
  163. if i not in actor_indexes:
  164. new_actors.append(self.actors[i])
  165. self.actors = new_actors
  166. def add_actors(self, num_actors: int):
  167. """Adds ``num_actors`` to this ActorGroup.
  168. Args:
  169. num_actors: The number of actors to add.
  170. """
  171. new_actors = []
  172. new_actor_metadata = []
  173. for _ in range(num_actors):
  174. actor = self._remote_cls.remote(
  175. *self.actor_config.init_args, **self.actor_config.init_kwargs
  176. )
  177. new_actors.append(actor)
  178. if hasattr(actor, "get_actor_metadata"):
  179. new_actor_metadata.append(actor.get_actor_metadata.remote())
  180. # Get metadata from all actors.
  181. metadata = ray.get(new_actor_metadata)
  182. if len(metadata) == 0:
  183. metadata = [None] * len(new_actors)
  184. for i in range(len(new_actors)):
  185. self.actors.append(ActorWrapper(actor=new_actors[i], metadata=metadata[i]))
  186. @property
  187. def actor_metadata(self):
  188. return [a.metadata for a in self.actors]