cached_channel.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. import uuid
  2. from typing import Any, Optional
  3. from ray.experimental.channel.common import ChannelInterface
  4. class CachedChannel(ChannelInterface):
  5. """
  6. CachedChannel wraps an inner channel and caches the data read from it until
  7. `num_reads` reads have completed. If inner channel is None, the data
  8. is written to serialization context and retrieved from there. This is useful
  9. when passing data within the same actor and a shared memory channel can be
  10. avoided.
  11. Args:
  12. num_reads: The number of reads from this channel that must happen before
  13. writing again. Readers must be methods of the same actor.
  14. inner_channel: The inner channel to cache data from. If None, the data is
  15. read from the serialization context.
  16. _channel_id: The unique ID for the channel. If None, a new ID is generated.
  17. """
  18. def __init__(
  19. self,
  20. num_reads: int,
  21. inner_channel: Optional[ChannelInterface] = None,
  22. _channel_id: Optional[str] = None,
  23. ):
  24. assert num_reads > 0, "num_reads must be greater than 0."
  25. self._num_reads = num_reads
  26. self._inner_channel = inner_channel
  27. # Generate a unique ID for the channel. The writer and reader will use
  28. # this ID to store and retrieve data from the _SerializationContext.
  29. self._channel_id = _channel_id
  30. if self._channel_id is None:
  31. self._channel_id = str(uuid.uuid4())
  32. def ensure_registered_as_writer(self) -> None:
  33. if self._inner_channel is not None:
  34. self._inner_channel.ensure_registered_as_writer()
  35. def ensure_registered_as_reader(self) -> None:
  36. if self._inner_channel is not None:
  37. self._inner_channel.ensure_registered_as_reader()
  38. def __reduce__(self):
  39. return CachedChannel, (
  40. self._num_reads,
  41. self._inner_channel,
  42. self._channel_id,
  43. )
  44. def __str__(self) -> str:
  45. return (
  46. f"CachedChannel(channel_id={self._channel_id}, "
  47. f"num_reads={self._num_reads}), "
  48. f"inner_channel={self._inner_channel})"
  49. )
  50. def write(self, value: Any, timeout: Optional[float] = None):
  51. self.ensure_registered_as_writer()
  52. # TODO: better organize the imports
  53. from ray.experimental.channel import ChannelContext
  54. if self._inner_channel is not None:
  55. self._inner_channel.write(value, timeout)
  56. return
  57. # Otherwise no need to check timeout as the operation is non-blocking.
  58. # Because both the reader and writer are in the same worker process,
  59. # we can directly store the data in the context instead of storing
  60. # it in the channel object. This removes the serialization overhead of `value`.
  61. ctx = ChannelContext.get_current().serialization_context
  62. ctx.set_data(self._channel_id, value, self._num_reads)
  63. def read(self, timeout: Optional[float] = None) -> Any:
  64. self.ensure_registered_as_reader()
  65. # TODO: better organize the imports
  66. from ray.experimental.channel import ChannelContext
  67. ctx = ChannelContext.get_current().serialization_context
  68. if ctx.has_data(self._channel_id):
  69. # No need to check timeout as the operation is non-blocking.
  70. return ctx.get_data(self._channel_id)
  71. assert (
  72. self._inner_channel is not None
  73. ), "Cannot read from the serialization context while inner channel is None."
  74. value = self._inner_channel.read(timeout)
  75. ctx.set_data(self._channel_id, value, self._num_reads)
  76. # NOTE: Currently we make a contract with Compiled Graph users that the
  77. # channel results should not be mutated by the actor methods.
  78. # When the user needs to modify the channel results, they should
  79. # make a copy of the channel results and modify the copy.
  80. # This is the same contract as used in IntraProcessChannel.
  81. # This contract is NOT enforced right now in either case.
  82. # TODO(rui): introduce a flag to control the behavior:
  83. # for example, by default we make a deep copy of the channel
  84. # result, but the user can turn off the deep copy for performance
  85. # improvements.
  86. # https://github.com/ray-project/ray/issues/47409
  87. return ctx.get_data(self._channel_id)
  88. def close(self) -> None:
  89. from ray.experimental.channel import ChannelContext
  90. if self._inner_channel is not None:
  91. self._inner_channel.close()
  92. ctx = ChannelContext.get_current().serialization_context
  93. ctx.reset_data(self._channel_id)