common.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681
  1. import asyncio
  2. import concurrent
  3. import sys
  4. import threading
  5. import time
  6. from dataclasses import dataclass
  7. from typing import (
  8. TYPE_CHECKING,
  9. Any,
  10. Callable,
  11. Dict,
  12. List,
  13. NamedTuple,
  14. Optional,
  15. Tuple,
  16. Union,
  17. )
  18. import ray
  19. import ray.exceptions
  20. from ray.experimental.channel.accelerator_context import AcceleratorContext
  21. from ray.experimental.channel.communicator import Communicator
  22. from ray.experimental.channel.communicator_handle import CommunicatorHandle
  23. from ray.experimental.channel.serialization_context import _SerializationContext
  24. from ray.util.annotations import DeveloperAPI, PublicAPI
# The context singleton on this process. It starts out unset and is created
# lazily by `ChannelContext.get_current()`.
_default_context: "Optional[ChannelContext]" = None
# Held by `ChannelContext.get_current()` while it checks for and creates the
# singleton above.
_context_lock = threading.Lock()

if TYPE_CHECKING:
    # Only needed for the quoted "torch.*" annotations in this module; the
    # runtime import is attempted lazily in `ChannelContext.torch_available`.
    import torch
  30. def retry_and_check_interpreter_exit(f: Callable[[], None]) -> bool:
  31. """This function is only useful when f contains channel read/write.
  32. Keep retrying channel read/write inside `f` and check if interpreter exits.
  33. It is important in case the read/write happens in a separate thread pool.
  34. See https://github.com/ray-project/ray/pull/47702
  35. f should a function that doesn't receive any input and return nothing.
  36. """
  37. exiting = False
  38. while True:
  39. try:
  40. f()
  41. break
  42. except ray.exceptions.RayChannelTimeoutError:
  43. if sys.is_finalizing():
  44. # Interpreter exits. We should ignore the error and
  45. # stop reading so that the thread can join.
  46. exiting = True
  47. break
  48. return exiting
# Holds the input arguments for Compiled Graph
@PublicAPI(stability="alpha")
class CompiledDAGArgs(NamedTuple):
    """Positional and keyword input arguments bundled into a single value."""

    # Positional arguments; `_adapt` indexes into this with integer keys.
    args: Tuple[Any, ...]
    # Keyword arguments; `_adapt` looks up non-integer keys here.
    kwargs: Dict[str, Any]
  54. @PublicAPI(stability="alpha")
  55. class ChannelOutputType:
  56. def register_custom_serializer(self) -> None:
  57. """
  58. Register any custom serializers needed to pass data of this type. This
  59. method should be run on the reader(s) and writer of a channel, which
  60. are the driver and/or Ray actors.
  61. NOTE: When custom serializers are registered with Ray, the registered
  62. deserializer is shipped with the serialized value and used on the
  63. receiving end. Therefore, the deserializer function should *not*
  64. capture state that is meant to be worker-local, such as the worker's
  65. default device. Instead, these should be extracted from the
  66. worker-local _SerializationContext.
  67. """
  68. pass
  69. def create_channel(
  70. self,
  71. writer: Optional["ray.actor.ActorHandle"],
  72. reader_and_node_list: List[Tuple["ray.actor.ActorHandle", str]],
  73. driver_actor_id: Optional[str] = None,
  74. ) -> "ChannelInterface":
  75. """
  76. Instantiate a ChannelInterface class that can be used
  77. to pass data of this type.
  78. Args:
  79. writer: The actor that may write to the channel. None signifies the driver.
  80. reader_and_node_list: A list of tuples, where each tuple contains a reader
  81. actor handle and the node ID where the actor is located.
  82. driver_actor_id: If this is a CompositeChannel that is read by a driver and
  83. that driver is an actual actor, this will be the actor ID of that
  84. driver actor.
  85. Returns:
  86. A ChannelInterface that can be used to pass data
  87. of this type.
  88. """
  89. raise NotImplementedError
  90. def requires_accelerator(self) -> bool:
  91. # By default, channels do not require accelerator.
  92. return False
  93. def get_custom_communicator(self) -> Optional[Communicator]:
  94. """
  95. Return the custom communicator group if one is specified.
  96. """
  97. return None
  98. def set_communicator_id(self, group_id: str) -> None:
  99. raise NotImplementedError
  100. @DeveloperAPI
  101. @dataclass
  102. class ChannelContext:
  103. serialization_context = _SerializationContext()
  104. _torch_available: Optional[bool] = None
  105. _torch_device: Optional["torch.device"] = None
  106. _current_stream: Optional["torch.cuda.Stream"] = None
  107. def __init__(self):
  108. # Used for the torch.Tensor accelerator transport.
  109. self.communicators: Dict[str, "Communicator"] = {}
  110. # Used for driver process to store actors in the communicator.
  111. self.communicator_handles: Dict[str, "CommunicatorHandle"] = {}
  112. @staticmethod
  113. def get_current() -> "ChannelContext":
  114. """Get or create a singleton context.
  115. If the context has not yet been created in this process, it will be
  116. initialized with default settings.
  117. """
  118. global _default_context
  119. with _context_lock:
  120. if _default_context is None:
  121. _default_context = ChannelContext()
  122. return _default_context
  123. @property
  124. def torch_available(self) -> bool:
  125. """
  126. Check if torch package is available.
  127. """
  128. if self._torch_available is not None:
  129. return self._torch_available
  130. try:
  131. import torch # noqa: F401
  132. except ImportError:
  133. self._torch_available = False
  134. return False
  135. self._torch_available = True
  136. return True
  137. @property
  138. def torch_device(self) -> "torch.device":
  139. if self._torch_device is None:
  140. self._torch_device = AcceleratorContext.get().get_accelerator_devices()[0]
  141. return self._torch_device
  142. def set_torch_device(self, device: "torch.device"):
  143. self._torch_device = device
  144. @PublicAPI(stability="alpha")
  145. class ChannelInterface:
  146. """
  147. Abstraction for a transport between a writer actor and some number of
  148. reader actors.
  149. """
  150. def __init__(
  151. self,
  152. writer: Optional[ray.actor.ActorHandle],
  153. readers: List[Optional[ray.actor.ActorHandle]],
  154. typ: Optional["ChannelOutputType"],
  155. ):
  156. """
  157. Create a channel that can be read and written by a Ray driver or actor.
  158. Args:
  159. writer: The actor that may write to the channel. None signifies the driver.
  160. readers: The actors that may read from the channel. None signifies
  161. the driver.
  162. typ: Type information about the values passed through the channel.
  163. """
  164. pass
  165. def ensure_registered_as_writer(self):
  166. """
  167. Check whether the process is a valid writer. This method must be idempotent.
  168. """
  169. raise NotImplementedError
  170. def ensure_registered_as_reader(self):
  171. """
  172. Check whether the process is a valid reader. This method must be idempotent.
  173. """
  174. raise NotImplementedError
  175. def write(self, value: Any, timeout: Optional[float] = None) -> None:
  176. """
  177. Write a value to the channel.
  178. Blocks if there are still pending readers for the previous value. The
  179. writer may not write again until the specified number of readers have
  180. read the value.
  181. Args:
  182. value: The value to write.
  183. timeout: The maximum time in seconds to wait to write the value.
  184. None means using default timeout, 0 means immediate timeout
  185. (immediate success or timeout without blocking), -1 means
  186. infinite timeout (block indefinitely).
  187. """
  188. raise NotImplementedError
  189. def read(self, timeout: Optional[float] = None) -> Any:
  190. """
  191. Read the latest value from the channel. This call will block until a
  192. value is available to read.
  193. Subsequent calls to read() may *block* if the deserialized object is
  194. zero-copy (e.g., bytes or a numpy array) *and* the object is still in scope.
  195. Args:
  196. timeout: The maximum time in seconds to wait to read the value.
  197. None means using default timeout, 0 means immediate timeout
  198. (immediate success or timeout without blocking), -1 means
  199. infinite timeout (block indefinitely).
  200. Returns:
  201. Any: The deserialized value. If the deserialized value is an
  202. Exception, it will be returned directly instead of being raised.
  203. """
  204. raise NotImplementedError
  205. def close(self) -> None:
  206. """
  207. Close this channel. This method must not block and it must be made
  208. idempotent. Any existing values in the channel may be lost after the
  209. channel is closed.
  210. """
  211. raise NotImplementedError
  212. # Interfaces for channel I/O.
  213. @DeveloperAPI
  214. class ReaderInterface:
  215. def __init__(
  216. self,
  217. input_channels: List[ChannelInterface],
  218. ):
  219. assert isinstance(input_channels, list)
  220. for chan in input_channels:
  221. assert isinstance(chan, ChannelInterface)
  222. self._input_channels = input_channels
  223. self._closed = False
  224. self._num_reads = 0
  225. # A list of channels that were not read in the last `read` call
  226. # because the reader returned immediately when a RayTaskError was found.
  227. # These channels must be consumed before the next read to avoid reading
  228. # stale data remaining from the last read.
  229. self._leftover_channels: List[ChannelInterface] = []
  230. def get_num_reads(self) -> int:
  231. return self._num_reads
  232. def start(self):
  233. raise NotImplementedError
  234. def _read_list(self, timeout: Optional[float] = None) -> List[Any]:
  235. """
  236. Read a list of values from this reader.
  237. Args:
  238. timeout: The maximum time in seconds to wait for reading.
  239. None means using default timeout which is infinite, 0 means immediate
  240. timeout (immediate success or timeout without blocking), -1 means
  241. infinite timeout (block indefinitely).
  242. """
  243. raise NotImplementedError
  244. def read(self, timeout: Optional[float] = None) -> List[Any]:
  245. """
  246. Read from this reader.
  247. Args:
  248. timeout: The maximum time in seconds to wait for reading.
  249. None means using default timeout, 0 means immediate timeout
  250. (immediate success or timeout without blocking), -1 means
  251. infinite timeout (block indefinitely).
  252. """
  253. assert (
  254. timeout is None or timeout >= 0 or timeout == -1
  255. ), "Timeout must be non-negative or -1."
  256. outputs = self._read_list(timeout)
  257. self._num_reads += 1
  258. return outputs
  259. def close(self) -> None:
  260. self._closed = True
  261. for channel in self._input_channels:
  262. channel.close()
  263. def _consume_leftover_channels_if_needed(
  264. self, timeout: Optional[float] = None
  265. ) -> None:
  266. # Consume the channels that were not read in the last `read` call because a
  267. # RayTaskError was returned from another channel. If we don't do this, the
  268. # read operation will read stale versions of the object refs.
  269. #
  270. # If a RayTaskError is returned from a leftover channel, it will be ignored.
  271. # If a read operation times out, a RayChannelTimeoutError exception will be
  272. # raised.
  273. #
  274. # TODO(kevin85421): Currently, a DAG with NCCL channels and fast fail enabled
  275. # may not be reusable. Revisit this in the future.
  276. for c in self._leftover_channels:
  277. start_time = time.monotonic()
  278. c.read(timeout)
  279. if timeout is not None:
  280. timeout -= time.monotonic() - start_time
  281. timeout = max(timeout, 0)
  282. self._leftover_channels = []
  283. @DeveloperAPI
  284. class SynchronousReader(ReaderInterface):
  285. def __init__(
  286. self,
  287. input_channels: List[ChannelInterface],
  288. ):
  289. super().__init__(input_channels)
  290. def start(self):
  291. pass
  292. def _read_list(self, timeout: Optional[float] = None) -> List[Any]:
  293. self._consume_leftover_channels_if_needed(timeout)
  294. # We don't update `remaining_timeout` here because in the worst case,
  295. # consuming leftover channels requires reading all `_input_channels`,
  296. # which users expect to complete within the original `timeout`. Updating
  297. # `remaining_timeout` could cause unexpected timeouts in subsequent read
  298. # operations.
  299. # It is a special case that `timeout` is set to 0, which means
  300. # read once for each channel.
  301. is_zero_timeout = timeout == 0
  302. results = [None for _ in range(len(self._input_channels))]
  303. if timeout is None or timeout == -1:
  304. timeout = float("inf")
  305. timeout_point = time.monotonic() + timeout
  306. remaining_timeout = timeout
  307. from ray.dag import DAGContext
  308. ctx = DAGContext.get_current()
  309. iteration_timeout = ctx.read_iteration_timeout
  310. # Iterate over the input channels with a shorter timeout for each iteration
  311. # to detect RayTaskError early and fail fast.
  312. done_channels = set()
  313. while len(done_channels) < len(self._input_channels):
  314. for i, c in enumerate(self._input_channels):
  315. if c in done_channels:
  316. continue
  317. try:
  318. result = c.read(min(remaining_timeout, iteration_timeout))
  319. results[i] = result
  320. done_channels.add(c)
  321. if isinstance(result, ray.exceptions.RayTaskError):
  322. # If we raise an exception immediately, it will be considered
  323. # as a system error which will cause the execution loop to
  324. # exit. Hence, return immediately and let `_process_return_vals`
  325. # handle the exception.
  326. #
  327. # Return a list of RayTaskError so that the caller will not
  328. # get an undefined partial result.
  329. self._leftover_channels = [
  330. c for c in self._input_channels if c not in done_channels
  331. ]
  332. return [result for _ in range(len(self._input_channels))]
  333. except ray.exceptions.RayChannelTimeoutError as e:
  334. remaining_timeout = max(timeout_point - time.monotonic(), 0)
  335. if remaining_timeout == 0:
  336. raise e
  337. continue
  338. remaining_timeout = max(timeout_point - time.monotonic(), 0)
  339. if remaining_timeout == 0 and not is_zero_timeout:
  340. raise ray.exceptions.RayChannelTimeoutError(
  341. f"Cannot read all channels within {timeout} seconds"
  342. )
  343. return results
  344. def release_channel_buffers(self, timeout: Optional[float] = None) -> None:
  345. for c in self._input_channels:
  346. start_time = time.monotonic()
  347. assert hasattr(
  348. c, "release_buffer"
  349. ), "release_buffer() is only supported for shared memory channel "
  350. "(e.g., Channel, BufferedSharedMemoryChannel, CompositeChannel) "
  351. "and used between the last actor and the driver, but got a channel"
  352. f" of type {type(c)}."
  353. c.release_buffer(timeout)
  354. if timeout is not None:
  355. timeout -= time.monotonic() - start_time
  356. timeout = max(timeout, 0)
  357. @DeveloperAPI
  358. class AwaitableBackgroundReader(ReaderInterface):
  359. """
  360. Asyncio-compatible channel reader.
  361. The reader is constructed with an async queue of futures whose values it
  362. will fulfill. It uses a threadpool to execute the blocking calls to read
  363. from the input channel(s).
  364. """
  365. def __init__(
  366. self,
  367. input_channels: List[ChannelInterface],
  368. fut_queue: asyncio.Queue,
  369. ):
  370. super().__init__(input_channels)
  371. self._fut_queue = fut_queue
  372. self._background_task = None
  373. self._background_task_executor = concurrent.futures.ThreadPoolExecutor(
  374. max_workers=1, thread_name_prefix="channel.AwaitableBackgroundReader"
  375. )
  376. def start(self):
  377. self._background_task = asyncio.ensure_future(self.run())
  378. def _run(self):
  379. # Give it a default timeout 60 seconds to release the buffers
  380. # of the channels that were not read in the last `read` call.
  381. self._consume_leftover_channels_if_needed(60)
  382. results = [None for _ in range(len(self._input_channels))]
  383. from ray.dag import DAGContext
  384. ctx = DAGContext.get_current()
  385. iteration_timeout = ctx.read_iteration_timeout
  386. done_channels = set()
  387. while len(done_channels) < len(self._input_channels):
  388. for i, c in enumerate(self._input_channels):
  389. if c in done_channels:
  390. continue
  391. try:
  392. result = c.read(iteration_timeout)
  393. results[i] = result
  394. done_channels.add(c)
  395. if isinstance(result, ray.exceptions.RayTaskError):
  396. self._leftover_channels = [
  397. c for c in self._input_channels if c not in done_channels
  398. ]
  399. return [result for _ in range(len(self._input_channels))]
  400. except ray.exceptions.RayChannelTimeoutError:
  401. pass
  402. if sys.is_finalizing():
  403. return results
  404. return results
  405. async def run(self):
  406. loop = asyncio.get_running_loop()
  407. while not self._closed:
  408. res, fut = await asyncio.gather(
  409. loop.run_in_executor(self._background_task_executor, self._run),
  410. self._fut_queue.get(),
  411. return_exceptions=True,
  412. )
  413. # Set the result on the main thread.
  414. fut.set_result(res)
  415. # NOTE(swang): If the object is zero-copy deserialized, then it
  416. # will stay in scope as long as ret and the future are in scope.
  417. # Therefore, we must delete both here after fulfilling the future.
  418. del res
  419. del fut
  420. def close(self):
  421. super().close()
  422. self._background_task_executor.shutdown(cancel_futures=True)
  423. self._background_task.cancel()
  424. @DeveloperAPI
  425. class WriterInterface:
  426. def __init__(
  427. self,
  428. output_channels: List[ChannelInterface],
  429. output_idxs: List[Optional[Union[int, str]]],
  430. is_input=False,
  431. ):
  432. """
  433. Initialize the writer.
  434. Args:
  435. output_channels: The output channels to write to.
  436. output_idxs: The indices of the values to write to each channel.
  437. This has the same length as `output_channels`. If `is_input` is True,
  438. the index can be an integer or a string to retrieve the corresponding
  439. value from `args` or `kwargs` in the DAG's input. If `is_input`
  440. is False, the entire value is written if the index is None. Otherwise,
  441. the value at the specified index in the tuple is written.
  442. is_input: Whether the writer is DAG input writer or not.
  443. """
  444. assert len(output_channels) == len(output_idxs)
  445. self._output_channels = output_channels
  446. self._output_idxs = output_idxs
  447. self._closed = False
  448. self._num_writes = 0
  449. self._is_input = is_input
  450. def get_num_writes(self) -> int:
  451. return self._num_writes
  452. def start(self):
  453. raise NotImplementedError()
  454. def write(self, val: Any, timeout: Optional[float] = None) -> None:
  455. """
  456. Write the value.
  457. Args:
  458. timeout: The maximum time in seconds to wait for writing. 0 means
  459. immediate timeout (immediate success or timeout without blocking).
  460. -1 and None mean infinite timeout (blocks indefinitely).
  461. """
  462. raise NotImplementedError()
  463. def close(self) -> None:
  464. self._closed = True
  465. for channel in self._output_channels:
  466. channel.close()
  467. def _adapt(raw_args: Any, key: Optional[Union[int, str]], is_input: bool):
  468. """
  469. Adapt the raw arguments to the key. If `is_input` is True, this method will
  470. retrieve the value from the input data for an InputAttributeNode. Otherwise, it
  471. will retrieve either a partial value or the entire value from the output of
  472. a ClassMethodNode.
  473. Args:
  474. raw_args: The raw arguments to adapt.
  475. key: The key to adapt.
  476. is_input: Whether the writer is DAG input writer or not.
  477. """
  478. if is_input:
  479. if not isinstance(raw_args, CompiledDAGArgs):
  480. # Fast path for a single input.
  481. return raw_args
  482. else:
  483. args = raw_args.args
  484. kwargs = raw_args.kwargs
  485. if isinstance(key, int):
  486. return args[key]
  487. else:
  488. return kwargs[key]
  489. else:
  490. if key is not None:
  491. return raw_args[key]
  492. else:
  493. return raw_args
  494. @DeveloperAPI
  495. class SynchronousWriter(WriterInterface):
  496. def start(self):
  497. for channel in self._output_channels:
  498. channel.ensure_registered_as_writer()
  499. def write(self, val: Any, timeout: Optional[float] = None) -> None:
  500. # If it is an exception, there's only 1 return value.
  501. # We have to send the same data to all channels.
  502. if isinstance(val, Exception):
  503. if len(self._output_channels) > 1:
  504. val = tuple(val for _ in range(len(self._output_channels)))
  505. if not self._is_input:
  506. if len(self._output_channels) > 1:
  507. if not isinstance(val, tuple):
  508. raise ValueError(
  509. f"Expected a tuple of {len(self._output_channels)} outputs, "
  510. f"but got {type(val)}"
  511. )
  512. if len(val) != len(self._output_channels):
  513. raise ValueError(
  514. f"Expected {len(self._output_channels)} outputs, but got "
  515. f"{len(val)} outputs"
  516. )
  517. for i, channel in enumerate(self._output_channels):
  518. idx = self._output_idxs[i]
  519. val_i = _adapt(val, idx, self._is_input)
  520. channel.write(val_i, timeout)
  521. self._num_writes += 1
  522. @DeveloperAPI
  523. class AwaitableBackgroundWriter(WriterInterface):
  524. def __init__(
  525. self,
  526. output_channels: List[ChannelInterface],
  527. output_idxs: List[Optional[Union[int, str]]],
  528. is_input=False,
  529. ):
  530. super().__init__(output_channels, output_idxs, is_input=is_input)
  531. self._queue = asyncio.Queue()
  532. self._background_task = None
  533. self._background_task_executor = concurrent.futures.ThreadPoolExecutor(
  534. max_workers=1, thread_name_prefix="channel.AwaitableBackgroundWriter"
  535. )
  536. def start(self):
  537. for channel in self._output_channels:
  538. channel.ensure_registered_as_writer()
  539. self._background_task = asyncio.ensure_future(self.run())
  540. def _run(self, res):
  541. if not self._is_input:
  542. if len(self._output_channels) > 1:
  543. if not isinstance(res, tuple):
  544. raise ValueError(
  545. f"Expected a tuple of {len(self._output_channels)} outputs, "
  546. f"but got {type(res)}"
  547. )
  548. if len(res) != len(self._output_channels):
  549. raise ValueError(
  550. f"Expected {len(self._output_channels)} outputs, but got "
  551. f"{len(res)} outputs"
  552. )
  553. for i, channel in enumerate(self._output_channels):
  554. idx = self._output_idxs[i]
  555. res_i = _adapt(res, idx, self._is_input)
  556. exiting = retry_and_check_interpreter_exit(
  557. lambda: channel.write(res_i, timeout=1)
  558. )
  559. if exiting:
  560. break
  561. async def run(self):
  562. loop = asyncio.get_event_loop()
  563. while True:
  564. res = await self._queue.get()
  565. await loop.run_in_executor(self._background_task_executor, self._run, res)
  566. async def write(self, val: Any) -> None:
  567. if self._closed:
  568. raise RuntimeError("DAG execution cancelled")
  569. await self._queue.put(val)
  570. self._num_writes += 1
  571. def close(self):
  572. self._background_task.cancel()
  573. super().close()