intra_process_channel.py 2.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. import uuid
  2. from typing import Any, Optional
  3. from ray.experimental.channel import ChannelContext
  4. from ray.experimental.channel.common import ChannelInterface
  5. from ray.util.annotations import PublicAPI
  6. @PublicAPI(stability="alpha")
  7. class IntraProcessChannel(ChannelInterface):
  8. """
  9. IntraProcessChannel is a channel for communication between two tasks in the same
  10. worker process. It writes data directly to the worker's _SerializationContext
  11. and reads data from the _SerializationContext to avoid the serialization
  12. overhead and the need for reading/writing from shared memory. Note that if the
  13. readers may mutate the data, users should deep copy the data themselves to avoid
  14. side effects.
  15. Args:
  16. num_readers: The number of readers that will read from this channel. Readers
  17. can be the same method of the same actor.
  18. """
  19. def __init__(
  20. self,
  21. num_readers,
  22. _channel_id: Optional[str] = None,
  23. ):
  24. # Generate a unique ID for the channel. The writer and reader will use
  25. # this ID to store and retrieve data from the _SerializationContext.
  26. self._channel_id = _channel_id
  27. self._num_readers = num_readers
  28. if self._channel_id is None:
  29. self._channel_id = str(uuid.uuid4())
  30. def ensure_registered_as_writer(self) -> None:
  31. pass
  32. def ensure_registered_as_reader(self) -> None:
  33. pass
  34. def __reduce__(self):
  35. return IntraProcessChannel, (
  36. self._num_readers,
  37. self._channel_id,
  38. )
  39. def __str__(self) -> str:
  40. return f"IntraProcessChannel(channel_id={self._channel_id})"
  41. def write(self, value: Any, timeout: Optional[float] = None):
  42. self.ensure_registered_as_writer()
  43. # No need to check timeout as the operation is non-blocking.
  44. # Because both the reader and writer are in the same worker process,
  45. # we can directly store the data in the context instead of storing
  46. # it in the channel object. This removes the serialization overhead of `value`.
  47. ctx = ChannelContext.get_current().serialization_context
  48. ctx.set_data(self._channel_id, value, self._num_readers)
  49. def read(self, timeout: Optional[float] = None, deserialize: bool = True) -> Any:
  50. self.ensure_registered_as_reader()
  51. assert deserialize, "Data passed from the actor to itself is never serialized"
  52. # No need to check timeout as the operation is non-blocking.
  53. ctx = ChannelContext.get_current().serialization_context
  54. return ctx.get_data(self._channel_id)
  55. def close(self) -> None:
  56. ctx = ChannelContext.get_current().serialization_context
  57. ctx.reset_data(self._channel_id)