actor_manager.py 45 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103
  1. import copy
  2. import logging
  3. import sys
  4. import time
  5. from collections import defaultdict
  6. from dataclasses import dataclass, field
  7. from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union
  8. import ray
  9. from ray.actor import ActorHandle
  10. from ray.exceptions import RayError, RayTaskError
  11. from ray.rllib.utils.typing import T
  12. from ray.util.annotations import DeveloperAPI
  13. logger = logging.getLogger(__name__)
  14. @DeveloperAPI
  15. class ResultOrError:
  16. """A wrapper around a result or a RayError thrown during remote task/actor calls.
  17. This is used to return data from `FaultTolerantActorManager` that allows us to
  18. distinguish between RayErrors (remote actor related) and valid results.
  19. """
  20. def __init__(self, result: Any = None, error: Exception = None):
  21. """One and only one of result or error should be set.
  22. Args:
  23. result: The result of the computation. Note that None is a valid result if
  24. the remote function does not return anything.
  25. error: Alternatively, the error that occurred during the computation.
  26. """
  27. self._result = result
  28. self._error = (
  29. # Easier to handle if we show the user the original error.
  30. error.as_instanceof_cause()
  31. if isinstance(error, RayTaskError)
  32. else error
  33. )
  34. @property
  35. def ok(self):
  36. return self._error is None
  37. def get(self):
  38. """Returns the result or the error."""
  39. if self._error:
  40. return self._error
  41. else:
  42. return self._result
  43. @DeveloperAPI
  44. @dataclass
  45. class CallResult:
  46. """Represents a single result from a call to an actor.
  47. Each CallResult contains the index of the actor that was called
  48. plus the result or error from the call.
  49. """
  50. actor_id: int
  51. result_or_error: ResultOrError
  52. tag: str
  53. @property
  54. def ok(self):
  55. """Passes through the ok property from the result_or_error."""
  56. return self.result_or_error.ok
  57. def get(self):
  58. """Passes through the get method from the result_or_error."""
  59. return self.result_or_error.get()
  60. @DeveloperAPI
  61. class RemoteCallResults:
  62. """Represents a list of results from calls to a set of actors.
  63. CallResults provides convenient APIs to iterate over the results
  64. while skipping errors, etc.
  65. .. testcode::
  66. :skipif: True
  67. manager = FaultTolerantActorManager(
  68. actors, max_remote_requests_in_flight_per_actor=2,
  69. )
  70. results = manager.foreach_actor(lambda w: w.call())
  71. # Iterate over all results ignoring errors.
  72. for result in results.ignore_errors():
  73. print(result.get())
  74. """
  75. class _Iterator:
  76. """An iterator over the results of a remote call."""
  77. def __init__(self, call_results: List[CallResult]):
  78. self._call_results = call_results
  79. def __iter__(self) -> Iterator[CallResult]:
  80. return self
  81. def __next__(self) -> CallResult:
  82. if not self._call_results:
  83. raise StopIteration
  84. return self._call_results.pop(0)
  85. def __init__(self):
  86. self.result_or_errors: List[CallResult] = []
  87. def add_result(self, actor_id: int, result_or_error: ResultOrError, tag: str):
  88. """Add index of a remote actor plus the call result to the list.
  89. Args:
  90. actor_id: ID of the remote actor.
  91. result_or_error: The result or error from the call.
  92. tag: A description to identify the call.
  93. """
  94. self.result_or_errors.append(CallResult(actor_id, result_or_error, tag))
  95. def __iter__(self) -> Iterator[ResultOrError]:
  96. """Return an iterator over the results."""
  97. # Shallow copy the list.
  98. return self._Iterator(copy.copy(self.result_or_errors))
  99. def __len__(self) -> int:
  100. return len(self.result_or_errors)
  101. def ignore_errors(self) -> Iterator[ResultOrError]:
  102. """Return an iterator over the results, skipping all errors."""
  103. return self._Iterator([r for r in self.result_or_errors if r.ok])
  104. def ignore_ray_errors(self) -> Iterator[ResultOrError]:
  105. """Return an iterator over the results, skipping only Ray errors.
  106. Similar to ignore_errors, but only skips Errors raised because of
  107. remote actor problems (often get restored automatcially).
  108. This is useful for callers that want to handle application errors differently
  109. from Ray errors.
  110. """
  111. return self._Iterator(
  112. [r for r in self.result_or_errors if not isinstance(r.get(), RayError)]
  113. )
  114. @DeveloperAPI
  115. class FaultAwareApply:
  116. @DeveloperAPI
  117. def ping(self) -> str:
  118. """Ping the actor. Can be used as a health check.
  119. Returns:
  120. "pong" if actor is up and well.
  121. """
  122. return "pong"
  123. @DeveloperAPI
  124. def apply(
  125. self,
  126. func: Callable[[Any, Optional[Any], Optional[Any]], T],
  127. *args,
  128. **kwargs,
  129. ) -> T:
  130. """Calls the given function with this Actor instance.
  131. A generic interface for applying arbitrary member functions on a
  132. remote actor.
  133. Args:
  134. func: The function to call, with this actor as first
  135. argument, followed by args, and kwargs.
  136. args: Optional additional args to pass to the function call.
  137. kwargs: Optional additional kwargs to pass to the function call.
  138. Returns:
  139. The return value of the function call.
  140. """
  141. try:
  142. return func(self, *args, **kwargs)
  143. except Exception as e:
  144. # Actor should be recreated by Ray.
  145. if self.config.restart_failed_env_runners:
  146. logger.exception(f"Worker exception caught during `apply()`: {e}")
  147. # Small delay to allow logs messages to propagate.
  148. time.sleep(self.config.delay_between_env_runner_restarts_s)
  149. # Kill this worker so Ray Core can restart it.
  150. sys.exit(1)
  151. # Actor should be left dead.
  152. else:
  153. raise e
  154. @DeveloperAPI
  155. class FaultTolerantActorManager:
  156. """A manager that is aware of the healthiness of remote actors.
  157. .. testcode::
  158. import time
  159. import ray
  160. from ray.rllib.utils.actor_manager import FaultTolerantActorManager
  161. @ray.remote
  162. class MyActor:
  163. def apply(self, func):
  164. return func(self)
  165. def do_something(self):
  166. return True
  167. actors = [MyActor.remote() for _ in range(3)]
  168. manager = FaultTolerantActorManager(
  169. actors, max_remote_requests_in_flight_per_actor=2,
  170. )
  171. # Synchronous remote calls.
  172. results = manager.foreach_actor(lambda actor: actor.do_something())
  173. # Print results ignoring returned errors.
  174. print([r.get() for r in results.ignore_errors()])
  175. # Asynchronous remote calls.
  176. manager.foreach_actor_async(lambda actor: actor.do_something())
  177. time.sleep(2) # Wait for the tasks to finish.
  178. for r in manager.fetch_ready_async_reqs():
  179. # Handle result and errors.
  180. if r.ok:
  181. print(r.get())
  182. else:
  183. print("Error: {}".format(r.get()))
  184. """
  185. @dataclass
  186. class _ActorState:
  187. """State of a single actor."""
  188. # Num of outstanding async requests for this actor by tag.
  189. num_in_flight_async_requests_by_tag: Dict[Optional[str], int] = field(
  190. default_factory=dict
  191. )
  192. # Whether this actor is in a healthy state.
  193. is_healthy: bool = True
  194. def get_num_in_flight_requests(self, tag: Optional[str] = None) -> int:
  195. """Get number of in-flight requests for a specific tag or all tags."""
  196. if tag is None:
  197. return sum(self.num_in_flight_async_requests_by_tag.values())
  198. return self.num_in_flight_async_requests_by_tag.get(tag, 0)
  199. def increment_requests(self, tag: Optional[str] = None) -> None:
  200. """Increment the count of in-flight requests for a tag."""
  201. if tag not in self.num_in_flight_async_requests_by_tag:
  202. self.num_in_flight_async_requests_by_tag[tag] = 0
  203. self.num_in_flight_async_requests_by_tag[tag] += 1
  204. def decrement_requests(self, tag: Optional[str] = None) -> None:
  205. """Decrement the count of in-flight requests for a tag."""
  206. if tag in self.num_in_flight_async_requests_by_tag:
  207. self.num_in_flight_async_requests_by_tag[tag] -= 1
  208. if self.num_in_flight_async_requests_by_tag[tag] <= 0:
  209. del self.num_in_flight_async_requests_by_tag[tag]
  210. def __init__(
  211. self,
  212. actors: Optional[List[ActorHandle]] = None,
  213. max_remote_requests_in_flight_per_actor: int = 2,
  214. init_id: int = 0,
  215. ):
  216. """Construct a FaultTolerantActorManager.
  217. Args:
  218. actors: A list of ray remote actors to manage on. These actors must have an
  219. ``apply`` method which takes a function with only one parameter (the
  220. actor instance itself).
  221. max_remote_requests_in_flight_per_actor: The maximum number of remote
  222. requests that can be in flight per actor. Any requests made to the pool
  223. that cannot be scheduled because the limit has been reached will be
  224. dropped. This only applies to the asynchronous remote call mode.
  225. init_id: The initial ID to use for the next remote actor. Default is 0.
  226. """
  227. # For round-robin style async requests, keep track of which actor to send
  228. # a new func next (current).
  229. self._next_id = self._current_actor_id = init_id
  230. # Actors are stored in a map and indexed by a unique (int) ID.
  231. self._actors: Dict[int, ActorHandle] = {}
  232. self._remote_actor_states: Dict[int, self._ActorState] = {}
  233. self._restored_actors = set()
  234. self.add_actors(actors or [])
  235. # Maps outstanding async requests to the IDs of the actor IDs that
  236. # are executing them.
  237. self._in_flight_req_to_actor_id: Dict[ray.ObjectRef, int] = {}
  238. self._max_remote_requests_in_flight_per_actor = (
  239. max_remote_requests_in_flight_per_actor
  240. )
  241. # Useful metric.
  242. self._num_actor_restarts = 0
  243. @DeveloperAPI
  244. def actor_ids(self) -> List[int]:
  245. """Returns a list of all worker IDs (healthy or not)."""
  246. return list(self._actors.keys())
  247. @DeveloperAPI
  248. def healthy_actor_ids(self) -> List[int]:
  249. """Returns a list of worker IDs that are healthy."""
  250. return [k for k, v in self._remote_actor_states.items() if v.is_healthy]
  251. @DeveloperAPI
  252. def add_actors(self, actors: List[ActorHandle]):
  253. """Add a list of actors to the pool.
  254. Args:
  255. actors: A list of ray remote actors to be added to the pool.
  256. """
  257. for actor in actors:
  258. self._actors[self._next_id] = actor
  259. self._remote_actor_states[self._next_id] = self._ActorState()
  260. self._next_id += 1
  261. @DeveloperAPI
  262. def remove_actor(self, actor_id: int) -> ActorHandle:
  263. """Remove an actor from the pool.
  264. Args:
  265. actor_id: ID of the actor to remove.
  266. Returns:
  267. Handle to the actor that was removed.
  268. """
  269. actor = self._actors[actor_id]
  270. # Remove the actor from the pool.
  271. del self._actors[actor_id]
  272. del self._remote_actor_states[actor_id]
  273. self._restored_actors.discard(actor_id)
  274. self._remove_async_state(actor_id)
  275. return actor
  276. @DeveloperAPI
  277. def num_actors(self) -> int:
  278. """Return the total number of actors in the pool."""
  279. return len(self._actors)
  280. @DeveloperAPI
  281. def num_healthy_actors(self) -> int:
  282. """Return the number of healthy remote actors."""
  283. return sum(s.is_healthy for s in self._remote_actor_states.values())
  284. @DeveloperAPI
  285. def total_num_restarts(self) -> int:
  286. """Return the number of remote actors that have been restarted."""
  287. return self._num_actor_restarts
  288. @DeveloperAPI
  289. def num_outstanding_async_reqs(self, tag: Optional[str] = None) -> int:
  290. """Return the number of outstanding async requests."""
  291. return sum(
  292. s.get_num_in_flight_requests(tag)
  293. for s in self._remote_actor_states.values()
  294. )
  295. @DeveloperAPI
  296. def is_actor_healthy(self, actor_id: int) -> bool:
  297. """Whether a remote actor is in healthy state.
  298. Args:
  299. actor_id: ID of the remote actor.
  300. Returns:
  301. True if the actor is healthy, False otherwise.
  302. """
  303. if actor_id not in self._remote_actor_states:
  304. raise ValueError(f"Unknown actor id: {actor_id}")
  305. return self._remote_actor_states[actor_id].is_healthy
  306. @DeveloperAPI
  307. def set_actor_state(self, actor_id: int, healthy: bool) -> None:
  308. """Update activate state for a specific remote actor.
  309. Args:
  310. actor_id: ID of the remote actor.
  311. healthy: Whether the remote actor is healthy.
  312. """
  313. if actor_id not in self._remote_actor_states:
  314. raise ValueError(f"Unknown actor id: {actor_id}")
  315. was_healthy = self._remote_actor_states[actor_id].is_healthy
  316. # Set from unhealthy to healthy -> Add to restored set.
  317. if not was_healthy and healthy:
  318. self._restored_actors.add(actor_id)
  319. # Set from healthy to unhealthy -> Remove from restored set.
  320. elif was_healthy and not healthy:
  321. self._restored_actors.discard(actor_id)
  322. self._remote_actor_states[actor_id].is_healthy = healthy
  323. if not healthy:
  324. # Remove any async states.
  325. self._remove_async_state(actor_id)
  326. @DeveloperAPI
  327. def clear(self):
  328. """Clean up managed actors."""
  329. for actor in self._actors.values():
  330. ray.kill(actor)
  331. self._actors.clear()
  332. self._remote_actor_states.clear()
  333. self._restored_actors.clear()
  334. self._in_flight_req_to_actor_id.clear()
  335. @DeveloperAPI
  336. def foreach_actor(
  337. self,
  338. func: Union[Callable[[Any], Any], List[Callable[[Any], Any]], str, List[str]],
  339. *,
  340. kwargs: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
  341. healthy_only: bool = True,
  342. remote_actor_ids: Optional[List[int]] = None,
  343. timeout_seconds: Optional[float] = None,
  344. return_obj_refs: bool = False,
  345. mark_healthy: bool = False,
  346. ) -> RemoteCallResults:
  347. """Calls the given function with each actor instance as arg.
  348. Automatically marks actors unhealthy if they crash during the remote call.
  349. Args:
  350. func: A single Callable applied to all specified remote actors or a list
  351. of Callables, that get applied on the list of specified remote actors.
  352. In the latter case, both list of Callables and list of specified actors
  353. must have the same length. Alternatively, you can use the name of the
  354. remote method to be called, instead, or a list of remote method names.
  355. kwargs: An optional single kwargs dict or a list of kwargs dict matching the
  356. list of provided `func` or `remote_actor_ids`. In the first case (single
  357. dict), use `kwargs` on all remote calls. The latter case (list of
  358. dicts) allows you to define individualized kwarg dicts per actor.
  359. healthy_only: If True, applies `func` only to actors currently tagged
  360. "healthy", otherwise to all actors. If `healthy_only=False` and
  361. `mark_healthy=True`, will send `func` to all actors and mark those
  362. actors "healthy" that respond to the request within `timeout_seconds`
  363. and are currently tagged as "unhealthy".
  364. remote_actor_ids: Apply func on a selected set of remote actors. Use None
  365. (default) for all actors.
  366. timeout_seconds: Time to wait (in seconds) for results. Set this to 0.0 for
  367. fire-and-forget. Set this to None (default) to wait infinitely (i.e. for
  368. synchronous execution).
  369. return_obj_refs: whether to return ObjectRef instead of actual results.
  370. Note, for fault tolerance reasons, these returned ObjectRefs should
  371. never be resolved with ray.get() outside of the context of this manager.
  372. mark_healthy: Whether to mark all those actors healthy again that are
  373. currently marked unhealthy AND that returned results from the remote
  374. call (within the given `timeout_seconds`).
  375. Note that actors are NOT set unhealthy, if they simply time out
  376. (only if they return a RayActorError).
  377. Also not that this setting is ignored if `healthy_only=True` (b/c this
  378. setting only affects actors that are currently tagged as unhealthy).
  379. Returns:
  380. The list of return values of all calls to `func(actor)`. The values may be
  381. actual data returned or exceptions raised during the remote call in the
  382. format of RemoteCallResults.
  383. """
  384. remote_actor_ids = remote_actor_ids or self.actor_ids()
  385. if healthy_only:
  386. func, kwargs, remote_actor_ids = self._filter_by_healthy_state(
  387. func=func, kwargs=kwargs, remote_actor_ids=remote_actor_ids
  388. )
  389. # Send out remote requests.
  390. remote_calls = self._call_actors(
  391. func=func,
  392. kwargs=kwargs,
  393. remote_actor_ids=remote_actor_ids,
  394. )
  395. # Collect remote request results (if available given timeout and/or errors).
  396. _, remote_results = self._fetch_result(
  397. remote_actor_ids=remote_actor_ids,
  398. remote_calls=remote_calls,
  399. tags=[None] * len(remote_calls),
  400. timeout_seconds=timeout_seconds,
  401. return_obj_refs=return_obj_refs,
  402. mark_healthy=mark_healthy,
  403. )
  404. return remote_results
  405. @DeveloperAPI
  406. def foreach_actor_async(
  407. self,
  408. func: Union[Callable[[Any], Any], List[Callable[[Any], Any]], str, List[str]],
  409. tag: Optional[str] = None,
  410. *,
  411. kwargs: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
  412. healthy_only: bool = True,
  413. remote_actor_ids: Optional[List[int]] = None,
  414. ) -> int:
  415. """Calls given functions against each actors without waiting for results.
  416. Args:
  417. func: A single Callable applied to all specified remote actors or a list
  418. of Callables, that get applied on the list of specified remote actors.
  419. In the latter case, both list of Callables and list of specified actors
  420. must have the same length. Alternatively, you can use the name of the
  421. remote method to be called, instead, or a list of remote method names.
  422. tag: A tag to identify the results from this async call.
  423. kwargs: An optional single kwargs dict or a list of kwargs dict matching the
  424. list of provided `func` or `remote_actor_ids`. In the first case (single
  425. dict), use `kwargs` on all remote calls. The latter case (list of
  426. dicts) allows you to define individualized kwarg dicts per actor.
  427. healthy_only: If True, applies `func` only to actors currently tagged
  428. "healthy", otherwise to all actors. If `healthy_only=False` and
  429. later, `self.fetch_ready_async_reqs()` is called with
  430. `mark_healthy=True`, will send `func` to all actors and mark those
  431. actors "healthy" that respond to the request within `timeout_seconds`
  432. and are currently tagged as "unhealthy".
  433. remote_actor_ids: Apply func on a selected set of remote actors.
  434. Note, for fault tolerance reasons, these returned ObjectRefs should
  435. never be resolved with ray.get() outside of the context of this manager.
  436. Returns:
  437. The number of async requests that are actually fired.
  438. """
  439. # TODO(avnishn, jungong): so thinking about this a bit more, it would be the
  440. # best if we can attach multiple tags to an async all, like basically this
  441. # parameter should be tags:
  442. # For sync calls, tags would be ().
  443. # For async call users, they can attached multiple tags for a single call, like
  444. # ("rollout_worker", "sync_weight").
  445. # For async fetch result, we can also specify a single, or list of tags. For
  446. # example, ("eval", "sample") will fetch all the sample() calls on eval
  447. # workers.
  448. if not remote_actor_ids:
  449. remote_actor_ids = self.actor_ids()
  450. num_calls = (
  451. len(func)
  452. if isinstance(func, list)
  453. else len(kwargs)
  454. if isinstance(kwargs, list)
  455. else len(remote_actor_ids)
  456. )
  457. # Perform round-robin assignment of all provided calls for any number of our
  458. # actors. Note that this way, some actors might receive more than 1 request in
  459. # this call.
  460. if num_calls != len(remote_actor_ids):
  461. remote_actor_ids = [
  462. (self._current_actor_id + i) % self.num_actors()
  463. for i in range(num_calls)
  464. ]
  465. # Update our round-robin pointer.
  466. self._current_actor_id += num_calls
  467. self._current_actor_id %= self.num_actors()
  468. if healthy_only:
  469. func, kwargs, remote_actor_ids = self._filter_by_healthy_state(
  470. func=func, kwargs=kwargs, remote_actor_ids=remote_actor_ids
  471. )
  472. num_calls_to_make: Dict[int, int] = defaultdict(lambda: 0)
  473. # Drop calls to actors that are too busy for this specific tag.
  474. if isinstance(func, list):
  475. assert len(func) == len(remote_actor_ids)
  476. limited_func = []
  477. limited_kwargs = []
  478. limited_remote_actor_ids = []
  479. for i, (f, raid) in enumerate(zip(func, remote_actor_ids)):
  480. num_outstanding_reqs_for_tag = self._remote_actor_states[
  481. raid
  482. ].get_num_in_flight_requests(tag)
  483. if (
  484. num_outstanding_reqs_for_tag + num_calls_to_make[raid]
  485. < self._max_remote_requests_in_flight_per_actor
  486. ):
  487. num_calls_to_make[raid] += 1
  488. k = kwargs[i] if isinstance(kwargs, list) else (kwargs or {})
  489. limited_func.append(f)
  490. limited_kwargs.append(k)
  491. limited_remote_actor_ids.append(raid)
  492. else:
  493. limited_func = func
  494. limited_kwargs = kwargs
  495. limited_remote_actor_ids = []
  496. for raid in remote_actor_ids:
  497. num_outstanding_reqs_for_tag = self._remote_actor_states[
  498. raid
  499. ].get_num_in_flight_requests(tag)
  500. if (
  501. num_outstanding_reqs_for_tag + num_calls_to_make[raid]
  502. < self._max_remote_requests_in_flight_per_actor
  503. ):
  504. num_calls_to_make[raid] += 1
  505. limited_remote_actor_ids.append(raid)
  506. if not limited_remote_actor_ids:
  507. return 0
  508. remote_calls = self._call_actors(
  509. func=limited_func,
  510. kwargs=limited_kwargs,
  511. remote_actor_ids=limited_remote_actor_ids,
  512. )
  513. # Save these as outstanding requests.
  514. for id, call in zip(limited_remote_actor_ids, remote_calls):
  515. self._remote_actor_states[id].increment_requests(tag)
  516. self._in_flight_req_to_actor_id[call] = (tag, id)
  517. return len(remote_calls)
  518. @DeveloperAPI
  519. def fetch_ready_async_reqs(
  520. self,
  521. *,
  522. tags: Union[str, List[str], Tuple[str, ...]] = (),
  523. timeout_seconds: Optional[float] = 0.0,
  524. return_obj_refs: bool = False,
  525. mark_healthy: bool = False,
  526. ) -> RemoteCallResults:
  527. """Get results from outstanding async requests that are ready.
  528. Automatically mark actors unhealthy if they fail to respond.
  529. Note: If tags is an empty tuple then results from all ready async requests are
  530. returned.
  531. Args:
  532. timeout_seconds: ray.get() timeout. Default is 0, which only fetched those
  533. results (immediately) that are already ready.
  534. tags: A tag or a list of tags to identify the results from this async call.
  535. return_obj_refs: Whether to return ObjectRef instead of actual results.
  536. mark_healthy: Whether to mark all those actors healthy again that are
  537. currently marked unhealthy AND that returned results from the remote
  538. call (within the given `timeout_seconds`).
  539. Note that actors are NOT set to unhealthy, if they simply time out,
  540. meaning take a longer time to fulfil the remote request. We only ever
  541. mark an actor unhealthy, if they raise a RayActorError inside the remote
  542. request.
  543. Also note that this settings is ignored if the preceding
  544. `foreach_actor_async()` call used the `healthy_only=True` argument (b/c
  545. `mark_healthy` only affects actors that are currently tagged as
  546. unhealthy).
  547. Returns:
  548. A list of return values of all calls to `func(actor)` that are ready.
  549. The values may be actual data returned or exceptions raised during the
  550. remote call in the format of RemoteCallResults.
  551. """
  552. # Construct the list of in-flight requests filtered by tag.
  553. remote_calls, remote_actor_ids, valid_tags = self._filter_calls_by_tag(tags)
  554. ready, remote_results = self._fetch_result(
  555. remote_actor_ids=remote_actor_ids,
  556. remote_calls=remote_calls,
  557. tags=valid_tags,
  558. timeout_seconds=timeout_seconds,
  559. return_obj_refs=return_obj_refs,
  560. mark_healthy=mark_healthy,
  561. )
  562. for obj_ref, result in zip(ready, remote_results):
  563. # Get the tag for this request and decrease outstanding request count by 1.
  564. if obj_ref in self._in_flight_req_to_actor_id:
  565. tag, actor_id = self._in_flight_req_to_actor_id[obj_ref]
  566. self._remote_actor_states[result.actor_id].decrement_requests(tag)
  567. # Remove this call from the in-flight list.
  568. del self._in_flight_req_to_actor_id[obj_ref]
  569. return remote_results
  570. @DeveloperAPI
  571. def foreach_actor_async_fetch_ready(
  572. self,
  573. func: Union[Callable[[Any], Any], List[Callable[[Any], Any]], str, List[str]],
  574. tag: Optional[str] = None,
  575. *,
  576. kwargs: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
  577. timeout_seconds: Optional[float] = 0.0,
  578. return_obj_refs: bool = False,
  579. mark_healthy: bool = False,
  580. healthy_only: bool = True,
  581. remote_actor_ids: Optional[List[int]] = None,
  582. ignore_ray_errors: bool = True,
  583. return_actor_ids: bool = False,
  584. ) -> List[Union[Tuple[int, Any], Any]]:
  585. """Calls the given function asynchronously and returns previous results if any.
  586. This is a convenience function that calls `fetch_ready_async_reqs()` to get
  587. previous results and then `foreach_actor_async()` to start new async calls.
  588. Args:
  589. func: A single Callable applied to all specified remote actors or a list
  590. of Callables, that get applied on the list of specified remote actors.
  591. In the latter case, both list of Callables and list of specified actors
  592. must have the same length. Alternatively, you can use the name of the
  593. remote method to be called, instead, or a list of remote method names.
  594. tag: A tag to identify the results from this async call.
  595. kwargs: An optional single kwargs dict or a list of kwargs dict matching the
  596. list of provided `func` or `remote_actor_ids`. In the first case (single
  597. dict), use `kwargs` on all remote calls. The latter case (list of
  598. dicts) allows you to define individualized kwarg dicts per actor.
  599. timeout_seconds: Time to wait for results from previous calls. Default is 0,
  600. meaning those requests that are already ready.
  601. return_obj_refs: Whether to return ObjectRef instead of actual results.
  602. mark_healthy: Whether to mark all those actors healthy again that are
  603. currently marked unhealthy AND that returned results from the remote
  604. call (within the given `timeout_seconds`).
  605. healthy_only: Apply `func` on known-to-be healthy actors only.
  606. remote_actor_ids: Apply func on a selected set of remote actors.
  607. ignore_ray_errors: Whether to ignore RayErrors in results.
  608. return_actor_ids: Whether to return actor IDs in the results.
  609. If True, the results will be a list of (actor_id, result) tuples.
  610. If False, the results will be a list of results.
  611. Returns:
  612. The results from previous async requests that were ready.
  613. """
  614. # First fetch any ready results from previous async calls
  615. remote_results = self.fetch_ready_async_reqs(
  616. tags=tag,
  617. timeout_seconds=timeout_seconds,
  618. return_obj_refs=return_obj_refs,
  619. mark_healthy=mark_healthy,
  620. )
  621. # Then start new async calls
  622. self.foreach_actor_async(
  623. func,
  624. tag=tag,
  625. kwargs=kwargs,
  626. healthy_only=healthy_only,
  627. remote_actor_ids=remote_actor_ids,
  628. )
  629. # Handle errors the same way as fetch_ready_async_reqs does
  630. FaultTolerantActorManager.handle_remote_call_result_errors(
  631. remote_results,
  632. ignore_ray_errors=ignore_ray_errors,
  633. )
  634. if return_actor_ids:
  635. return [(r.actor_id, r.get()) for r in remote_results.ignore_errors()]
  636. else:
  637. return [r.get() for r in remote_results.ignore_errors()]
  @staticmethod
  def handle_remote_call_result_errors(
      results_or_errors: RemoteCallResults,
      *,
      ignore_ray_errors: bool,
  ) -> None:
      """Checks given results for application errors and raises them if necessary.

      Elements whose `ok` flag is set are skipped. For every other element, the
      object returned by its `get()` (the error) is either raised right away or
      logged at error level, depending on `ignore_ray_errors`.

      Args:
          results_or_errors: The results or errors to check.
          ignore_ray_errors: Whether to ignore RayErrors within the elements of
              `results_or_errors`. If True, errors are only logged; if False, the
              first error encountered is raised.

      Raises:
          The error held by the first non-OK element, if `ignore_ray_errors` is
          False.
      """
      for result_or_error in results_or_errors:
          # Good result.
          if result_or_error.ok:
              continue

          error = result_or_error.get()

          # Raise RayError.
          if not ignore_ray_errors:
              raise error

          # RayError, but we ignore it. We are not inside an `except` block here,
          # so `logger.exception()` would attach the (empty) "current exception".
          # Pass the error explicitly to log its own type/traceback instead.
          logger.error(error, exc_info=error)
  @DeveloperAPI
  def probe_unhealthy_actors(
      self,
      timeout_seconds: Optional[float] = None,
      mark_healthy: bool = False,
  ) -> List[int]:
      """Pings all actors currently marked unhealthy to try bringing them back.

      Args:
          timeout_seconds: Timeout in seconds passed on to `foreach_actor` (to
              avoid pinging hanging workers indefinitely).
          mark_healthy: Passed on to `foreach_actor`. Whether to mark those
              actors healthy again that are currently marked unhealthy AND that
              respond to the `ping` remote request (within the given
              `timeout_seconds`).

      Returns:
          A list of the actor IDs cached in `self._restored_actors` before the
          ping, followed by the IDs of those actors whose `ping` result was OK.
          The cached set of previously restored actors is cleared by this call.
      """
      # Snapshot the restored-actor cache BEFORE pinging, so that only entries
      # recorded outside of this method end up in this first part of the result.
      previously_restored = list(self._restored_actors)

      # All actors that are currently flagged as unhealthy.
      ids_to_ping = [
          aid for aid in self.actor_ids() if not self.is_actor_healthy(aid)
      ]

      newly_restored = []
      if ids_to_ping:
          ping_results = self.foreach_actor(
              func=lambda actor: actor.ping(),
              remote_actor_ids=ids_to_ping,
              healthy_only=False,  # We specifically want to ping unhealthy actors.
              timeout_seconds=timeout_seconds,
              return_obj_refs=False,
              mark_healthy=mark_healthy,
          )
          # Only actors with an OK ping response count as "just restored".
          newly_restored.extend(res.actor_id for res in ping_results if res.ok)

      # Erase the cache of previously restored actors (those were restored through
      # other successful requests, outside of this method).
      self._restored_actors.clear()

      # Previously restored actors first, then the just restored ones.
      return previously_restored + newly_restored
  714. def _call_actors(
  715. self,
  716. func: Union[Callable[[Any], Any], List[Callable[[Any], Any]], str, List[str]],
  717. *,
  718. kwargs: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
  719. remote_actor_ids: List[int] = None,
  720. ) -> List[ray.ObjectRef]:
  721. """Apply functions on a list of remote actors.
  722. Args:
  723. func: A single Callable applied to all specified remote actors or a list
  724. of Callables, that get applied on the list of specified remote actors.
  725. In the latter case, both list of Callables and list of specified actors
  726. must have the same length. Alternatively, you can use the name of the
  727. remote method to be called, instead, or a list of remote method names.
  728. kwargs: An optional single kwargs dict or a list of kwargs dict matching the
  729. list of provided `func` or `remote_actor_ids`. In the first case (single
  730. dict), use `kwargs` on all remote calls. The latter case (list of
  731. dicts) allows you to define individualized kwarg dicts per actor.
  732. remote_actor_ids: Apply func on this selected set of remote actors.
  733. Returns:
  734. A list of ObjectRefs returned from the remote calls.
  735. """
  736. if remote_actor_ids is None:
  737. remote_actor_ids = self.actor_ids()
  738. calls = []
  739. if isinstance(func, list):
  740. assert len(remote_actor_ids) == len(
  741. func
  742. ), "Funcs must have the same number of callables as actor indices."
  743. assert isinstance(
  744. kwargs, list
  745. ), "If func is a list of functions, kwargs has to be a list of kwargs."
  746. for i, (raid, f) in enumerate(zip(remote_actor_ids, func)):
  747. if isinstance(f, str):
  748. calls.append(
  749. getattr(self._actors[raid], f).remote(
  750. **(
  751. kwargs[i]
  752. if isinstance(kwargs, list)
  753. else (kwargs or {})
  754. )
  755. )
  756. )
  757. else:
  758. calls.append(self._actors[raid].apply.remote(f))
  759. elif isinstance(func, str):
  760. for i, raid in enumerate(remote_actor_ids):
  761. calls.append(
  762. getattr(self._actors[raid], func).remote(
  763. **(kwargs[i] if isinstance(kwargs, list) else (kwargs or {}))
  764. )
  765. )
  766. else:
  767. for raid in remote_actor_ids:
  768. calls.append(self._actors[raid].apply.remote(func=func, **kwargs or {}))
  769. return calls
  @DeveloperAPI
  def _fetch_result(
      self,
      *,
      remote_actor_ids: List[int],
      remote_calls: List[ray.ObjectRef],
      tags: List[str],
      timeout_seconds: Optional[float] = None,
      return_obj_refs: bool = False,
      mark_healthy: bool = False,
  ) -> Tuple[List[ray.ObjectRef], RemoteCallResults]:
      """Try fetching results from remote actor calls.

      Mark whether an actor is healthy or not accordingly.

      Args:
          remote_actor_ids: IDs of the actors these remote
              calls were fired against.
          remote_calls: List of remote calls to fetch.
          tags: List of tags used for identifying the remote calls.
          timeout_seconds: Timeout (in sec) for the ray.wait() call. Default is None,
              meaning wait indefinitely for all results.
          return_obj_refs: Whether to return ObjectRef instead of actual results.
              If True, the ready refs are not resolved here and no actor health
              state is changed.
          mark_healthy: Whether to mark certain actors healthy based on the results
              of these remote calls. Useful, for example, to make sure actors
              do not come back without proper state restoration.

      Returns:
          A tuple consisting of the list of ready ObjectRefs and a
          `RemoteCallResults` object holding one entry (result or RayError) per
          ready ObjectRef.
      """
      # Notice that we do not return the refs to any unfinished calls to the
      # user, since it is not safe to handle such remote actor calls outside the
      # context of this actor manager. These requests are simply dropped.
      timeout = float(timeout_seconds) if timeout_seconds is not None else None

      # This avoids calling ray.init() in the case of 0 remote calls.
      # This is useful if the number of remote workers is 0.
      if not remote_calls:
          return [], RemoteCallResults()

      readies, _ = ray.wait(
          remote_calls,
          num_returns=len(remote_calls),
          timeout=timeout,
          # Make sure remote results are fetched locally in parallel.
          fetch_local=not return_obj_refs,
      )

      # Map each remote call to its position once, instead of running a linear
      # `list.index()` search per ready ref. `setdefault` keeps the FIRST
      # position of a ref, matching what `list.index()` would return.
      position_of = {}
      for pos, call in enumerate(remote_calls):
          position_of.setdefault(call, pos)

      # Remote data should already be fetched to local object store at this point.
      remote_results = RemoteCallResults()
      for ready in readies:
          # Find the corresponding actor ID and tag for this remote call.
          pos = position_of[ready]
          actor_id = remote_actor_ids[pos]
          tag = tags[pos]

          # If caller wants ObjectRefs, return directly without resolving.
          if return_obj_refs:
              remote_results.add_result(actor_id, ResultOrError(result=ready), tag)
              continue

          # Try getting the ready results.
          try:
              result = ray.get(ready)
          # Any error type other than `RayError` happening during ray.get() ->
          # Throw exception right here (we don't know how to handle these non-remote
          # worker issues and should therefore crash).
          except RayError as e:
              # Return error to the user.
              remote_results.add_result(actor_id, ResultOrError(error=e), tag)
              # Mark the actor as unhealthy, take it out of service, and wait for
              # Ray Core to restore it.
              if self.is_actor_healthy(actor_id):
                  logger.error(
                      f"Ray error ({str(e)}), taking actor {actor_id} out of service."
                  )
              self.set_actor_state(actor_id, healthy=False)
          # If no errors, add result to `RemoteCallResults` to be returned.
          else:
              # Return valid result to the user.
              remote_results.add_result(actor_id, ResultOrError(result=result), tag)
              # Actor came back from an unhealthy state. Mark this actor as healthy
              # and add it to our healthy set.
              if mark_healthy and not self.is_actor_healthy(actor_id):
                  logger.warning(
                      f"Bringing previously unhealthy, now-healthy actor {actor_id} "
                      "back into service."
                  )
                  self.set_actor_state(actor_id, healthy=True)
                  self._num_actor_restarts += 1

      # Make sure, to-be-returned results are sound.
      assert len(readies) == len(remote_results)

      return readies, remote_results
  854. def _filter_by_healthy_state(
  855. self,
  856. *,
  857. func: Union[Callable[[Any], Any], List[Callable[[Any], Any]]],
  858. kwargs: Optional[Union[Dict, List[Dict]]] = None,
  859. remote_actor_ids: List[int],
  860. ):
  861. """Filter out func and remote worker ids by actor state.
  862. Args:
  863. func: A single, or a list of Callables.
  864. kwargs: An optional single kwargs dict or a list of kwargs dicts matching
  865. the list of provided `func` or `remote_actor_ids`. In case of a single
  866. dict, uses `kwargs` on all remote calls. In case of a list of dicts,
  867. the given kwarg dicts are per actor `func` or per `remote_actor_ids`.
  868. remote_actor_ids: IDs of potential remote workers to apply func on.
  869. Returns:
  870. A tuple of (filtered func, filtered remote worker ids).
  871. """
  872. if isinstance(func, list):
  873. assert len(remote_actor_ids) == len(
  874. func
  875. ), "Func must have the same number of callables as remote actor ids."
  876. # We are given a list of functions to apply.
  877. # Need to filter the functions together with worker IDs.
  878. temp_func = []
  879. temp_remote_actor_ids = []
  880. temp_kwargs = []
  881. for i, (f, raid) in enumerate(zip(func, remote_actor_ids)):
  882. if self.is_actor_healthy(raid):
  883. k = kwargs[i] if isinstance(kwargs, list) else (kwargs or {})
  884. temp_func.append(f)
  885. temp_kwargs.append(k)
  886. temp_remote_actor_ids.append(raid)
  887. func = temp_func
  888. kwargs = temp_kwargs
  889. remote_actor_ids = temp_remote_actor_ids
  890. else:
  891. # Simply filter the worker IDs.
  892. remote_actor_ids = [i for i in remote_actor_ids if self.is_actor_healthy(i)]
  893. return func, kwargs, remote_actor_ids
  894. def _filter_calls_by_tag(
  895. self, tags: Optional[Union[str, List[str], Tuple[str, ...]]] = None
  896. ) -> Tuple[List[ray.ObjectRef], List[ActorHandle], List[str]]:
  897. """Return all the in flight requests that match the given tags, if any.
  898. Args:
  899. tags: A str or a list/tuple of str. If tags is empty or None, return all the in
  900. flight requests.
  901. Returns:
  902. A tuple consisting of a list of the remote calls that match the tag(s),
  903. a list of the corresponding remote actor IDs for these calls (same length),
  904. and a list of the tags corresponding to these calls (same length).
  905. """
  906. if tags is None:
  907. tags = set()
  908. elif isinstance(tags, str):
  909. tags = {tags}
  910. elif isinstance(tags, (list, tuple)):
  911. tags = set(tags)
  912. else:
  913. raise ValueError(
  914. f"tags must be either a str or a list/tuple of str, got {type(tags)}."
  915. )
  916. remote_calls = []
  917. remote_actor_ids = []
  918. valid_tags = []
  919. for call, (tag, actor_id) in self._in_flight_req_to_actor_id.items():
  920. # the default behavior is to return all ready results.
  921. if len(tags) == 0 or tag in tags:
  922. remote_calls.append(call)
  923. remote_actor_ids.append(actor_id)
  924. valid_tags.append(tag)
  925. return remote_calls, remote_actor_ids, valid_tags
  926. def _remove_async_state(self, actor_id: int):
  927. """Remove internal async state of for a given actor.
  928. This is called when an actor is removed from the pool or being marked
  929. unhealthy.
  930. Args:
  931. actor_id: The id of the actor.
  932. """
  933. # Remove any outstanding async requests for this actor.
  934. # Use `list` here to not change a looped generator while we mutate the
  935. # underlying dict.
  936. for req, (tag, id) in list(self._in_flight_req_to_actor_id.items()):
  937. if id == actor_id:
  938. del self._in_flight_req_to_actor_id[req]
  939. # Clear all tag-based request counts for this actor
  940. if actor_id in self._remote_actor_states:
  941. self._remote_actor_states[
  942. actor_id
  943. ].num_in_flight_async_requests_by_tag.clear()
  def actors(self):
      """Return the manager's internal actor container (`self._actors`).

      The container itself is returned, not a copy.
      """
      # TODO(jungong) : remove this API once EnvRunnerGroup.remote_workers()
      # and EnvRunnerGroup._remote_workers() are removed.
      managed_actors = self._actors
      return managed_actors