pool.py 37 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008
  1. import collections
  2. import copy
  3. import gc
  4. import itertools
  5. import logging
  6. import os
  7. import queue
  8. import sys
  9. import threading
  10. import time
  11. from multiprocessing import TimeoutError
  12. from typing import Any, Callable, Dict, Hashable, Iterable, List, Optional, Tuple
  13. import ray
  14. from ray._common.usage import usage_lib
  15. from ray.util import log_once
  16. try:
  17. from joblib._parallel_backends import SafeFunction
  18. from joblib.parallel import BatchedCalls, parallel_backend
  19. except ImportError:
  20. BatchedCalls = None
  21. parallel_backend = None
  22. SafeFunction = None
  23. logger = logging.getLogger(__name__)
  24. RAY_ADDRESS_ENV = "RAY_ADDRESS"
  25. def _put_in_dict_registry(
  26. obj: Any, registry_hashable: Dict[Hashable, ray.ObjectRef]
  27. ) -> ray.ObjectRef:
  28. if obj not in registry_hashable:
  29. ret = ray.put(obj)
  30. registry_hashable[obj] = ret
  31. else:
  32. ret = registry_hashable[obj]
  33. return ret
  34. def _put_in_list_registry(
  35. obj: Any, registry: List[Tuple[Any, ray.ObjectRef]]
  36. ) -> ray.ObjectRef:
  37. try:
  38. ret = next((ref for o, ref in registry if o is obj))
  39. except StopIteration:
  40. ret = ray.put(obj)
  41. registry.append((obj, ret))
  42. return ret
  43. def ray_put_if_needed(
  44. obj: Any,
  45. registry: Optional[List[Tuple[Any, ray.ObjectRef]]] = None,
  46. registry_hashable: Optional[Dict[Hashable, ray.ObjectRef]] = None,
  47. ) -> ray.ObjectRef:
  48. """ray.put obj in object store if it's not an ObjRef and bigger than 100 bytes,
  49. with support for list and dict registries"""
  50. if isinstance(obj, ray.ObjectRef) or sys.getsizeof(obj) < 100:
  51. return obj
  52. ret = obj
  53. if registry_hashable is not None:
  54. try:
  55. ret = _put_in_dict_registry(obj, registry_hashable)
  56. except TypeError:
  57. if registry is not None:
  58. ret = _put_in_list_registry(obj, registry)
  59. elif registry is not None:
  60. ret = _put_in_list_registry(obj, registry)
  61. return ret
  62. def ray_get_if_needed(obj: Any) -> Any:
  63. """If obj is an ObjectRef, do ray.get, otherwise return obj"""
  64. if isinstance(obj, ray.ObjectRef):
  65. return ray.get(obj)
  66. return obj
  67. if BatchedCalls is not None:
  68. class RayBatchedCalls(BatchedCalls):
  69. """Joblib's BatchedCalls with basic Ray object store management
  70. This functionality is provided through the put_items_in_object_store,
  71. which uses external registries (list and dict) containing objects
  72. and their ObjectRefs."""
  73. def put_items_in_object_store(
  74. self,
  75. registry: Optional[List[Tuple[Any, ray.ObjectRef]]] = None,
  76. registry_hashable: Optional[Dict[Hashable, ray.ObjectRef]] = None,
  77. ):
  78. """Puts all applicable (kw)args in self.items in object store
  79. Takes two registries - list for unhashable objects and dict
  80. for hashable objects. The registries are a part of a Pool object.
  81. The method iterates through all entries in items list (usually,
  82. there will be only one, but the number depends on joblib Parallel
  83. settings) and puts all of the args and kwargs into the object
  84. store, updating the registries.
  85. If an arg or kwarg is already in a registry, it will not be
  86. put again, and instead, the cached object ref will be used."""
  87. new_items = []
  88. for func, args, kwargs in self.items:
  89. args = [
  90. ray_put_if_needed(arg, registry, registry_hashable) for arg in args
  91. ]
  92. kwargs = {
  93. k: ray_put_if_needed(v, registry, registry_hashable)
  94. for k, v in kwargs.items()
  95. }
  96. new_items.append((func, args, kwargs))
  97. self.items = new_items
  98. def __call__(self):
  99. # Exactly the same as in BatchedCalls, with the
  100. # difference being that it gets args and kwargs from
  101. # object store (which have been put in there by
  102. # put_items_in_object_store)
  103. # Set the default nested backend to self._backend but do
  104. # not set the change the default number of processes to -1
  105. with parallel_backend(self._backend, n_jobs=self._n_jobs):
  106. return [
  107. func(
  108. *[ray_get_if_needed(arg) for arg in args],
  109. **{k: ray_get_if_needed(v) for k, v in kwargs.items()},
  110. )
  111. for func, args, kwargs in self.items
  112. ]
  113. def __reduce__(self):
  114. # Exactly the same as in BatchedCalls, with the
  115. # difference being that it returns RayBatchedCalls
  116. # instead
  117. if self._reducer_callback is not None:
  118. self._reducer_callback()
  119. # no need pickle the callback.
  120. return (
  121. RayBatchedCalls,
  122. (self.items, (self._backend, self._n_jobs), None, self._pickle_cache),
  123. )
  124. else:
  125. RayBatchedCalls = None
  126. # Helper function to divide a by b and round the result up.
  127. def div_round_up(a, b):
  128. return -(-a // b)
  129. class PoolTaskError(Exception):
  130. def __init__(self, underlying):
  131. self.underlying = underlying
  132. class ResultThread(threading.Thread):
  133. """Thread that collects results from distributed actors.
  134. It winds down when either:
  135. - A pre-specified number of objects has been processed
  136. - When the END_SENTINEL (submitted through self.add_object_ref())
  137. has been received and all objects received before that have been
  138. processed.
  139. Initialize the thread with total_object_refs = float('inf') to wait for the
  140. END_SENTINEL.
  141. Args:
  142. object_refs (List[RayActorObjectRefs]): ObjectRefs to Ray Actor calls.
  143. Thread tracks whether they are ready. More ObjectRefs may be added
  144. with add_object_ref (or _add_object_ref internally) until the object
  145. count reaches total_object_refs.
  146. single_result: Should be True if the thread is managing function
  147. with a single result (like apply_async). False if the thread is managing
  148. a function with a List of results.
  149. callback: called only once at the end of the thread
  150. if no results were errors. If single_result=True, and result is
  151. not an error, callback is invoked with the result as the only
  152. argument. If single_result=False, callback is invoked with
  153. a list of all the results as the only argument.
  154. error_callback: called only once on the first result
  155. that errors. Should take an Exception as the only argument.
  156. If no result errors, this callback is not called.
  157. total_object_refs: Number of ObjectRefs that this thread
  158. expects to be ready. May be more than len(object_refs) since
  159. more ObjectRefs can be submitted after the thread starts.
  160. If None, defaults to len(object_refs). If float("inf"), thread runs
  161. until END_SENTINEL (submitted through self.add_object_ref())
  162. has been received and all objects received before that have
  163. been processed.
  164. """
  165. END_SENTINEL = None
  166. def __init__(
  167. self,
  168. object_refs: list,
  169. single_result: bool = False,
  170. callback: callable = None,
  171. error_callback: callable = None,
  172. total_object_refs: Optional[int] = None,
  173. ):
  174. threading.Thread.__init__(self, daemon=True)
  175. self._got_error = False
  176. self._object_refs = []
  177. self._num_ready = 0
  178. self._results = []
  179. self._ready_index_queue = queue.Queue()
  180. self._single_result = single_result
  181. self._callback = callback
  182. self._error_callback = error_callback
  183. self._total_object_refs = total_object_refs or len(object_refs)
  184. self._indices = {}
  185. # Thread-safe queue used to add ObjectRefs to fetch after creating
  186. # this thread (used to lazily submit for imap and imap_unordered).
  187. self._new_object_refs = queue.Queue()
  188. for object_ref in object_refs:
  189. self._add_object_ref(object_ref)
  190. def _add_object_ref(self, object_ref):
  191. self._indices[object_ref] = len(self._object_refs)
  192. self._object_refs.append(object_ref)
  193. self._results.append(None)
  194. def add_object_ref(self, object_ref):
  195. self._new_object_refs.put(object_ref)
  196. def run(self):
  197. unready = copy.copy(self._object_refs)
  198. aggregated_batch_results = []
  199. # Run for a specific number of objects if self._total_object_refs is finite.
  200. # Otherwise, process all objects received prior to the stop signal, given by
  201. # self.add_object(END_SENTINEL).
  202. while self._num_ready < self._total_object_refs:
  203. # Get as many new IDs from the queue as possible without blocking,
  204. # unless we have no IDs to wait on, in which case we block.
  205. ready_id = None
  206. while ready_id is None:
  207. try:
  208. block = len(unready) == 0
  209. new_object_ref = self._new_object_refs.get(block=block)
  210. if new_object_ref is self.END_SENTINEL:
  211. # Receiving the END_SENTINEL object is the signal to stop.
  212. # Store the total number of objects.
  213. self._total_object_refs = len(self._object_refs)
  214. else:
  215. self._add_object_ref(new_object_ref)
  216. unready.append(new_object_ref)
  217. except queue.Empty:
  218. # queue.Empty means no result was retrieved if block=False.
  219. pass
  220. # Check if any of the available IDs are done. The timeout is required
  221. # here to periodically check for new IDs from self._new_object_refs.
  222. # NOTE(edoakes): the choice of a 100ms timeout here is arbitrary. Too
  223. # low of a timeout would cause higher overhead from busy spinning and
  224. # too high would cause higher tail latency to fetch the first result in
  225. # some cases.
  226. ready, unready = ray.wait(unready, num_returns=1, timeout=0.1)
  227. if len(ready) > 0:
  228. ready_id = ready[0]
  229. try:
  230. batch = ray.get(ready_id)
  231. except ray.exceptions.RayError as e:
  232. batch = [e]
  233. # The exception callback is called only once on the first result
  234. # that errors. If no result errors, it is never called.
  235. if not self._got_error:
  236. for result in batch:
  237. if isinstance(result, Exception):
  238. self._got_error = True
  239. if self._error_callback is not None:
  240. self._error_callback(result)
  241. break
  242. else:
  243. aggregated_batch_results.append(result)
  244. self._num_ready += 1
  245. self._results[self._indices[ready_id]] = batch
  246. self._ready_index_queue.put(self._indices[ready_id])
  247. # The regular callback is called only once on the entire List of
  248. # results as long as none of the results were errors. If any results
  249. # were errors, the regular callback is never called; instead, the
  250. # exception callback is called on the first erroring result.
  251. #
  252. # This callback is called outside the while loop to ensure that it's
  253. # called on the entire list of results– not just a single batch.
  254. if not self._got_error and self._callback is not None:
  255. if not self._single_result:
  256. self._callback(aggregated_batch_results)
  257. else:
  258. # On a thread handling a function with a single result
  259. # (e.g. apply_async), we call the callback on just that result
  260. # instead of on a list encaspulating that result
  261. self._callback(aggregated_batch_results[0])
  262. def got_error(self):
  263. # Should only be called after the thread finishes.
  264. return self._got_error
  265. def result(self, index):
  266. # Should only be called on results that are ready.
  267. return self._results[index]
  268. def results(self):
  269. # Should only be called after the thread finishes.
  270. return self._results
  271. def next_ready_index(self, timeout=None):
  272. try:
  273. return self._ready_index_queue.get(timeout=timeout)
  274. except queue.Empty:
  275. # queue.Queue signals a timeout by raising queue.Empty.
  276. raise TimeoutError
  277. class AsyncResult:
  278. """An asynchronous interface to task results.
  279. This should not be constructed directly.
  280. """
  281. def __init__(
  282. self, chunk_object_refs, callback=None, error_callback=None, single_result=False
  283. ):
  284. self._single_result = single_result
  285. self._result_thread = ResultThread(
  286. chunk_object_refs, single_result, callback, error_callback
  287. )
  288. self._result_thread.start()
  289. def wait(self, timeout=None):
  290. """
  291. Returns once the result is ready or the timeout expires (does not
  292. raise TimeoutError).
  293. Args:
  294. timeout: timeout in milliseconds.
  295. """
  296. self._result_thread.join(timeout)
  297. def get(self, timeout=None):
  298. self.wait(timeout)
  299. if self._result_thread.is_alive():
  300. raise TimeoutError
  301. results = []
  302. for batch in self._result_thread.results():
  303. for result in batch:
  304. if isinstance(result, PoolTaskError):
  305. raise result.underlying
  306. elif isinstance(result, Exception):
  307. raise result
  308. results.extend(batch)
  309. if self._single_result:
  310. return results[0]
  311. return results
  312. def ready(self):
  313. """
  314. Returns true if the result is ready, else false if the tasks are still
  315. running.
  316. """
  317. return not self._result_thread.is_alive()
  318. def successful(self):
  319. """
  320. Returns true if none of the submitted tasks errored, else false. Should
  321. only be called once the result is ready (can be checked using `ready`).
  322. """
  323. if not self.ready():
  324. raise ValueError(f"{self!r} not ready")
  325. return not self._result_thread.got_error()
  326. class IMapIterator:
  327. """Base class for OrderedIMapIterator and UnorderedIMapIterator."""
  328. def __init__(self, pool, func, iterable, chunksize=None):
  329. self._pool = pool
  330. self._func = func
  331. self._next_chunk_index = 0
  332. self._finished_iterating = False
  333. # List of bools indicating if the given chunk is ready or not for all
  334. # submitted chunks. Ordering mirrors that in the in the ResultThread.
  335. self._submitted_chunks = []
  336. self._ready_objects = collections.deque()
  337. self._iterator = iter(iterable)
  338. if isinstance(iterable, collections.abc.Iterator):
  339. # Got iterator (which has no len() function).
  340. # Make default chunksize 1 instead of using _calculate_chunksize().
  341. # Indicate unknown queue length, requiring explicit stopping.
  342. self._chunksize = chunksize or 1
  343. result_list_size = float("inf")
  344. else:
  345. self._chunksize = chunksize or pool._calculate_chunksize(iterable)
  346. result_list_size = div_round_up(len(iterable), chunksize)
  347. self._result_thread = ResultThread([], total_object_refs=result_list_size)
  348. self._result_thread.start()
  349. for _ in range(len(self._pool._actor_pool)):
  350. self._submit_next_chunk()
  351. def _submit_next_chunk(self):
  352. # The full iterable has already been submitted, so no-op.
  353. if self._finished_iterating:
  354. return
  355. actor_index = len(self._submitted_chunks) % len(self._pool._actor_pool)
  356. chunk_iterator = itertools.islice(self._iterator, self._chunksize)
  357. # Check whether we have run out of samples.
  358. # This consumes the original iterator, so we convert to a list and back
  359. chunk_list = list(chunk_iterator)
  360. if len(chunk_list) < self._chunksize:
  361. # Reached end of self._iterator
  362. self._finished_iterating = True
  363. if len(chunk_list) == 0:
  364. # Nothing to do, return.
  365. return
  366. chunk_iterator = iter(chunk_list)
  367. new_chunk_id = self._pool._submit_chunk(
  368. self._func, chunk_iterator, self._chunksize, actor_index
  369. )
  370. self._submitted_chunks.append(False)
  371. # Wait for the result
  372. self._result_thread.add_object_ref(new_chunk_id)
  373. # If we submitted the final chunk, notify the result thread
  374. if self._finished_iterating:
  375. self._result_thread.add_object_ref(ResultThread.END_SENTINEL)
  376. def __iter__(self):
  377. return self
  378. def __next__(self):
  379. return self.next()
  380. def next(self):
  381. # Should be implemented by subclasses.
  382. raise NotImplementedError
  383. class OrderedIMapIterator(IMapIterator):
  384. """Iterator to the results of tasks submitted using `imap`.
  385. The results are returned in the same order that they were submitted, even
  386. if they don't finish in that order. Only one batch of tasks per actor
  387. process is submitted at a time - the rest are submitted as results come in.
  388. Should not be constructed directly.
  389. """
  390. def next(self, timeout=None):
  391. if len(self._ready_objects) == 0:
  392. if self._finished_iterating and (
  393. self._next_chunk_index == len(self._submitted_chunks)
  394. ):
  395. # Finish when all chunks have been dispatched and processed
  396. # Notify the calling process that the work is done.
  397. raise StopIteration
  398. # This loop will break when the next index in order is ready or
  399. # self._result_thread.next_ready_index() raises a timeout.
  400. index = -1
  401. while index != self._next_chunk_index:
  402. start = time.time()
  403. index = self._result_thread.next_ready_index(timeout=timeout)
  404. self._submit_next_chunk()
  405. self._submitted_chunks[index] = True
  406. if timeout is not None:
  407. timeout = max(0, timeout - (time.time() - start))
  408. while (
  409. self._next_chunk_index < len(self._submitted_chunks)
  410. and self._submitted_chunks[self._next_chunk_index]
  411. ):
  412. for result in self._result_thread.result(self._next_chunk_index):
  413. self._ready_objects.append(result)
  414. self._next_chunk_index += 1
  415. return self._ready_objects.popleft()
  416. class UnorderedIMapIterator(IMapIterator):
  417. """Iterator to the results of tasks submitted using `imap`.
  418. The results are returned in the order that they finish. Only one batch of
  419. tasks per actor process is submitted at a time - the rest are submitted as
  420. results come in.
  421. Should not be constructed directly.
  422. """
  423. def next(self, timeout=None):
  424. if len(self._ready_objects) == 0:
  425. if self._finished_iterating and (
  426. self._next_chunk_index == len(self._submitted_chunks)
  427. ):
  428. # Finish when all chunks have been dispatched and processed
  429. # Notify the calling process that the work is done.
  430. raise StopIteration
  431. index = self._result_thread.next_ready_index(timeout=timeout)
  432. self._submit_next_chunk()
  433. for result in self._result_thread.result(index):
  434. self._ready_objects.append(result)
  435. self._next_chunk_index += 1
  436. return self._ready_objects.popleft()
  437. @ray.remote(num_cpus=0)
  438. class PoolActor:
  439. """Actor used to process tasks submitted to a Pool."""
  440. def __init__(self, initializer=None, initargs=None):
  441. if initializer:
  442. initargs = initargs or ()
  443. initializer(*initargs)
  444. def ping(self):
  445. # Used to wait for this actor to be initialized.
  446. pass
  447. def run_batch(self, func, batch):
  448. results = []
  449. for args, kwargs in batch:
  450. args = args or ()
  451. kwargs = kwargs or {}
  452. try:
  453. results.append(func(*args, **kwargs))
  454. except Exception as e:
  455. results.append(PoolTaskError(e))
  456. return results
  457. # https://docs.python.org/3/library/multiprocessing.html#module-multiprocessing.pool
  458. class Pool:
  459. """A pool of actor processes that is used to process tasks in parallel.
  460. Args:
  461. processes: number of actor processes to start in the pool. Defaults to
  462. the number of cores in the Ray cluster if one is already running,
  463. otherwise the number of cores on this machine.
  464. initializer: function to be run in each actor when it starts up.
  465. initargs: iterable of arguments to the initializer function.
  466. maxtasksperchild: maximum number of tasks to run in each actor process.
  467. After a process has executed this many tasks, it will be killed and
  468. replaced with a new one.
  469. ray_address: address of the Ray cluster to run on. If None, a new local
  470. Ray cluster will be started on this machine. Otherwise, this will
  471. be passed to `ray.init()` to connect to a running cluster. This may
  472. also be specified using the `RAY_ADDRESS` environment variable.
  473. ray_remote_args: arguments used to configure the Ray Actors making up
  474. the pool. See :func:`ray.remote` for details.
  475. """
  476. def __init__(
  477. self,
  478. processes: Optional[int] = None,
  479. initializer: Optional[Callable] = None,
  480. initargs: Optional[Iterable] = None,
  481. maxtasksperchild: Optional[int] = None,
  482. context: Any = None,
  483. ray_address: Optional[str] = None,
  484. ray_remote_args: Optional[Dict[str, Any]] = None,
  485. ):
  486. usage_lib.record_library_usage("util.multiprocessing.Pool")
  487. self._closed = False
  488. self._initializer = initializer
  489. self._initargs = initargs
  490. self._maxtasksperchild = maxtasksperchild or -1
  491. self._actor_deletion_ids = []
  492. self._registry: List[Tuple[Any, ray.ObjectRef]] = []
  493. self._registry_hashable: Dict[Hashable, ray.ObjectRef] = {}
  494. self._current_index = 0
  495. self._ray_remote_args = ray_remote_args or {}
  496. self._pool_actor = None
  497. if context and log_once("context_argument_warning"):
  498. logger.warning(
  499. "The 'context' argument is not supported using "
  500. "ray. Please refer to the documentation for how "
  501. "to control ray initialization."
  502. )
  503. processes = self._init_ray(processes, ray_address)
  504. self._start_actor_pool(processes)
  505. def _init_ray(self, processes=None, ray_address=None):
  506. # Initialize ray. If ray is already initialized, we do nothing.
  507. # Else, the priority is:
  508. # ray_address argument > RAY_ADDRESS > start new local cluster.
  509. if not ray.is_initialized():
  510. # Cluster mode.
  511. if ray_address is None and (
  512. RAY_ADDRESS_ENV in os.environ
  513. or ray._private.utils.read_ray_address() is not None
  514. ):
  515. init_kwargs = {}
  516. if os.environ.get(RAY_ADDRESS_ENV) == "local":
  517. init_kwargs["num_cpus"] = processes
  518. ray.init(**init_kwargs)
  519. elif ray_address is not None:
  520. init_kwargs = {}
  521. if ray_address == "local":
  522. init_kwargs["num_cpus"] = processes
  523. ray.init(address=ray_address, **init_kwargs)
  524. # Local mode.
  525. else:
  526. ray.init(num_cpus=processes)
  527. ray_cpus = int(ray._private.state.cluster_resources()["CPU"])
  528. if processes is None:
  529. processes = ray_cpus
  530. if processes <= 0:
  531. raise ValueError("Processes in the pool must be >0.")
  532. if ray_cpus < processes:
  533. raise ValueError(
  534. "Tried to start a pool with {} processes on an "
  535. "existing ray cluster, but there are only {} "
  536. "CPUs in the ray cluster.".format(processes, ray_cpus)
  537. )
  538. return processes
  539. def _start_actor_pool(self, processes):
  540. self._pool_actor = None
  541. self._actor_pool = [self._new_actor_entry() for _ in range(processes)]
  542. ray.get([actor.ping.remote() for actor, _ in self._actor_pool])
  543. def _wait_for_stopping_actors(self, timeout=None):
  544. if len(self._actor_deletion_ids) == 0:
  545. return
  546. if timeout is not None:
  547. timeout = float(timeout)
  548. _, deleting = ray.wait(
  549. self._actor_deletion_ids,
  550. num_returns=len(self._actor_deletion_ids),
  551. timeout=timeout,
  552. )
  553. self._actor_deletion_ids = deleting
  554. def _stop_actor(self, actor):
  555. # Check and clean up any outstanding IDs corresponding to deletions.
  556. self._wait_for_stopping_actors(timeout=0.0)
  557. # The deletion task will block until the actor has finished executing
  558. # all pending tasks.
  559. self._actor_deletion_ids.append(actor.__ray_terminate__.remote())
  560. def _new_actor_entry(self):
  561. # NOTE(edoakes): The initializer function can't currently be used to
  562. # modify the global namespace (e.g., import packages or set globals)
  563. # due to a limitation in cloudpickle.
  564. # Cache the PoolActor with options
  565. if not self._pool_actor:
  566. self._pool_actor = PoolActor.options(**self._ray_remote_args)
  567. return (self._pool_actor.remote(self._initializer, self._initargs), 0)
  568. def _next_actor_index(self):
  569. if self._current_index == len(self._actor_pool) - 1:
  570. self._current_index = 0
  571. else:
  572. self._current_index += 1
  573. return self._current_index
  574. # Batch should be a list of tuples: (args, kwargs).
  575. def _run_batch(self, actor_index, func, batch):
  576. actor, count = self._actor_pool[actor_index]
  577. object_ref = actor.run_batch.remote(func, batch)
  578. count += 1
  579. assert self._maxtasksperchild == -1 or count <= self._maxtasksperchild
  580. if count == self._maxtasksperchild:
  581. self._stop_actor(actor)
  582. actor, count = self._new_actor_entry()
  583. self._actor_pool[actor_index] = (actor, count)
  584. return object_ref
  585. def apply(
  586. self,
  587. func: Callable,
  588. args: Optional[Tuple] = None,
  589. kwargs: Optional[Dict] = None,
  590. ):
  591. """Run the given function on a random actor process and return the
  592. result synchronously.
  593. Args:
  594. func: function to run.
  595. args: optional arguments to the function.
  596. kwargs: optional keyword arguments to the function.
  597. Returns:
  598. The result.
  599. """
  600. return self.apply_async(func, args, kwargs).get()
  601. def apply_async(
  602. self,
  603. func: Callable,
  604. args: Optional[Tuple] = None,
  605. kwargs: Optional[Dict] = None,
  606. callback: Callable[[Any], None] = None,
  607. error_callback: Callable[[Exception], None] = None,
  608. ):
  609. """Run the given function on a random actor process and return an
  610. asynchronous interface to the result.
  611. Args:
  612. func: function to run.
  613. args: optional arguments to the function.
  614. kwargs: optional keyword arguments to the function.
  615. callback: callback to be executed on the result once it is finished
  616. only if it succeeds.
  617. error_callback: callback to be executed the result once it is
  618. finished only if the task errors. The exception raised by the
  619. task will be passed as the only argument to the callback.
  620. Returns:
  621. AsyncResult containing the result.
  622. """
  623. self._check_running()
  624. func = self._convert_to_ray_batched_calls_if_needed(func)
  625. object_ref = self._run_batch(self._next_actor_index(), func, [(args, kwargs)])
  626. return AsyncResult([object_ref], callback, error_callback, single_result=True)
  627. def _convert_to_ray_batched_calls_if_needed(self, func: Callable) -> Callable:
  628. """Convert joblib's BatchedCalls to RayBatchedCalls for ObjectRef caching.
  629. This converts joblib's BatchedCalls callable, which is a collection of
  630. functions with their args and kwargs to be ran sequentially in an
  631. Actor, to a RayBatchedCalls callable, which provides identical
  632. functionality in addition to a method which ensures that common
  633. args and kwargs are put into the object store just once, saving time
  634. and memory. That method is then ran.
  635. If func is not a BatchedCalls instance, it is returned without changes.
  636. The ObjectRefs are cached inside two registries (_registry and
  637. _registry_hashable), which are common for the entire Pool and are
  638. cleaned on close."""
  639. if RayBatchedCalls is None:
  640. return func
  641. orginal_func = func
  642. # SafeFunction is a Python 2 leftover and can be
  643. # safely removed.
  644. if isinstance(func, SafeFunction):
  645. func = func.func
  646. if isinstance(func, BatchedCalls):
  647. func = RayBatchedCalls(
  648. func.items,
  649. (func._backend, func._n_jobs),
  650. func._reducer_callback,
  651. func._pickle_cache,
  652. )
  653. # go through all the items and replace args and kwargs with
  654. # ObjectRefs, caching them in registries
  655. func.put_items_in_object_store(self._registry, self._registry_hashable)
  656. else:
  657. func = orginal_func
  658. return func
  659. def _calculate_chunksize(self, iterable):
  660. chunksize, extra = divmod(len(iterable), len(self._actor_pool) * 4)
  661. if extra:
  662. chunksize += 1
  663. return chunksize
  664. def _submit_chunk(self, func, iterator, chunksize, actor_index, unpack_args=False):
  665. chunk = []
  666. while len(chunk) < chunksize:
  667. try:
  668. args = next(iterator)
  669. if not unpack_args:
  670. args = (args,)
  671. chunk.append((args, {}))
  672. except StopIteration:
  673. break
  674. # Nothing to submit. The caller should prevent this.
  675. assert len(chunk) > 0
  676. return self._run_batch(actor_index, func, chunk)
  def _chunk_and_run(self, func, iterable, chunksize=None, unpack_args=False):
      """Split ``iterable`` into chunks and submit them round-robin.

      Args:
          func: function to run on every item.
          iterable: items to process. Anything without ``__len__`` is first
              materialized into a list so that its length is known.
          chunksize: number of items per submitted chunk. If None, it is
              computed by ``self._calculate_chunksize``.
          unpack_args: forwarded unchanged to ``self._submit_chunk``.

      Returns:
          A list with one entry per submitted chunk (whatever
          ``self._submit_chunk`` returns), in submission order.

      Raises:
          ValueError: if there are items to process but ``chunksize`` is
              zero or negative.
      """
      if not hasattr(iterable, "__len__"):
          iterable = list(iterable)
      num_items = len(iterable)
      if chunksize is None:
          chunksize = self._calculate_chunksize(iterable)
      # A non-positive chunksize can never make progress: every submitted
      # chunk would be empty, which only trips an internal assertion in
      # _submit_chunk, or loops forever when assertions are disabled with
      # ``python -O``. Reject it up front with a clear error instead.
      if num_items and chunksize <= 0:
          raise ValueError(
              "chunksize must be a positive number, got {!r}".format(chunksize)
          )

      iterator = iter(iterable)
      chunk_object_refs = []
      # Keep submitting until the chunks handed out so far cover every item.
      while len(chunk_object_refs) * chunksize < num_items:
          # Round-robin: chunk i goes to pool entry i modulo the pool size.
          actor_index = len(chunk_object_refs) % len(self._actor_pool)
          chunk_object_refs.append(
              self._submit_chunk(
                  func, iterator, chunksize, actor_index, unpack_args=unpack_args
              )
          )
      return chunk_object_refs
  def _map_async(
      self,
      func,
      iterable,
      chunksize=None,
      unpack_args=False,
      callback=None,
      error_callback=None,
  ):
      """Submit ``iterable`` in chunks and wrap the pending work in an AsyncResult.

      ``self._check_running`` is called first, so nothing is submitted if
      that check raises. ``callback`` and ``error_callback`` are handed to
      the AsyncResult unchanged.
      """
      self._check_running()
      pending_refs = self._chunk_and_run(
          func, iterable, chunksize=chunksize, unpack_args=unpack_args
      )
      result = AsyncResult(pending_refs, callback, error_callback)
      return result
  def map(self, func: Callable, iterable: Iterable, chunksize: Optional[int] = None):
      """Apply func to every element of the iterable, distributing the work
      round-robin over the actor processes, and wait for the results.

      Args:
          func: function to run.
          iterable: iterable of objects, each passed to func as its sole
              argument.
          chunksize: number of tasks to submit as a batch to each actor
              process. If unspecified, a suitable chunksize will be chosen.

      Returns:
          A list of results.
      """
      pending = self._map_async(
          func, iterable, chunksize=chunksize, unpack_args=False
      )
      # Synchronous variant: wait on the async result right away.
      return pending.get()
  def map_async(
      self,
      func: Callable,
      iterable: Iterable,
      chunksize: Optional[int] = None,
      callback: Callable[[List], None] = None,
      error_callback: Callable[[Exception], None] = None,
  ):
      """Apply func to every element of the iterable, distributing the work
      round-robin over the actor processes, without waiting for the results.

      Args:
          func: function to run.
          iterable: iterable of objects, each passed to func as its only
              argument.
          chunksize: number of tasks to submit as a batch to each actor
              process. If unspecified, a suitable chunksize will be chosen.
          callback: Will only be called if none of the results were errors,
              and will only be called once after all results are finished.
              A Python List of all the finished results will be passed as the
              only argument to the callback.
          error_callback: callback executed on the first errored result.
              The Exception raised by the task will be passed as the only
              argument to the callback.

      Returns:
          AsyncResult
      """
      async_result = self._map_async(
          func,
          iterable,
          chunksize=chunksize,
          unpack_args=False,
          callback=callback,
          error_callback=error_callback,
      )
      return async_result
  def starmap(self, func, iterable, chunksize=None):
      """Like `map`, except every element of the iterable is unpacked into
      the positional arguments of func: [func(*args) for args in iterable].
      """
      pending = self._map_async(
          func, iterable, chunksize=chunksize, unpack_args=True
      )
      # Synchronous variant: wait on the async result right away.
      return pending.get()
  def starmap_async(
      self,
      func: Callable,
      iterable: Iterable,
      callback: Optional[Callable[[List], None]] = None,
      error_callback: Optional[Callable[[Exception], None]] = None,
      *,
      chunksize: Optional[int] = None,
  ):
      """Same as `map_async`, but unpacks each element of the iterable as the
      arguments to func like: [func(*args) for args in iterable].

      Args:
          func: function to run.
          iterable: iterable of argument tuples; each one is unpacked into
              the positional arguments of func.
          callback: passed through to `_map_async` unchanged.
          error_callback: passed through to `_map_async` unchanged.
          chunksize: number of tasks to submit as a batch. If unspecified,
              `_map_async` falls back to its own default. Keyword-only,
              because the third and fourth positional slots of this method
              are already taken by the callbacks and must stay that way for
              existing callers.

      Returns:
          Whatever `_map_async` returns for the submitted work.
      """
      return self._map_async(
          func,
          iterable,
          chunksize=chunksize,
          unpack_args=True,
          callback=callback,
          error_callback=error_callback,
      )
  def imap(self, func: Callable, iterable: Iterable, chunksize: Optional[int] = 1):
      """Lazy counterpart of `map`: only one batch of tasks is submitted to
      each actor process at a time.

      This can be useful if the iterable of arguments is very large or each
      task's arguments consumes a large amount of resources.

      The results are returned in the order corresponding to their arguments
      in the iterable.

      Returns:
          OrderedIMapIterator
      """
      # Refuse to hand out an iterator for a pool that is already closed.
      self._check_running()
      ordered_results = OrderedIMapIterator(
          self, func, iterable, chunksize=chunksize
      )
      return ordered_results
  def imap_unordered(
      self, func: Callable, iterable: Iterable, chunksize: Optional[int] = 1
  ):
      """Lazy counterpart of `map`: only one batch of tasks is submitted to
      each actor process at a time.

      This can be useful if the iterable of arguments is very large or each
      task's arguments consumes a large amount of resources.

      The results are returned in the order that they finish.

      Returns:
          UnorderedIMapIterator
      """
      # Refuse to hand out an iterator for a pool that is already closed.
      self._check_running()
      unordered_results = UnorderedIMapIterator(
          self, func, iterable, chunksize=chunksize
      )
      return unordered_results
  def _check_running(self):
      """Raise ValueError if the pool has been closed; otherwise do nothing."""
      if not self._closed:
          return
      raise ValueError("Pool not running")
  def __enter__(self):
      """Enter a ``with`` block: verify the pool is usable and return it."""
      # Entering a context on a closed pool fails here rather than later.
      self._check_running()
      return self
  def __exit__(self, exc_type, exc_val, exc_tb):
      """Leave a ``with`` block by terminating the pool.

      The exception details are ignored, and nothing is returned, so an
      exception raised inside the block is not suppressed.
      """
      self.terminate()
  def close(self):
      """Close the pool.

      Prevents any more tasks from being submitted on the pool but allows
      outstanding work to finish.
      """
      # Drop the cached entries of both registries, plain one first.
      for registry in (self._registry, self._registry_hashable):
          registry.clear()
      # Ask every actor in the pool to stop; the second element of each pool
      # entry is not needed here.
      for actor, _unused in self._actor_pool:
          self._stop_actor(actor)
      self._closed = True
      gc.collect()
  def terminate(self):
      """Close the pool.

      Prevents any more tasks from being submitted on the pool and stops
      outstanding work.
      """
      still_open = not self._closed
      if still_open:
          # Run the regular close path first.
          self.close()
      # Then kill every actor in the pool, whether or not close() just ran.
      for actor, _unused in self._actor_pool:
          ray.kill(actor)
  def join(self):
      """Wait for the actors in a closed pool to exit.

      If the pool was closed using `close`, this will return once all
      outstanding work is completed.

      If the pool was closed using `terminate`, this will return quickly.

      Raises:
          ValueError: if the pool has not been closed yet.
      """
      if self._closed:
          self._wait_for_stopping_actors()
          return
      raise ValueError("Pool is still running")