import uuid
from typing import Any, Optional

from ray.experimental.channel.common import ChannelInterface


class CachedChannel(ChannelInterface):
    """
    CachedChannel wraps an inner channel and caches the data read from it until
    `num_reads` reads have completed. If inner channel is None, the data is
    written to serialization context and retrieved from there. This is useful
    when passing data within the same actor and a shared memory channel can be
    avoided.

    Args:
        num_reads: The number of reads from this channel that must happen before
            writing again. Readers must be methods of the same actor.
        inner_channel: The inner channel to cache data from. If None, the data is
            read from the serialization context.
        _channel_id: The unique ID for the channel. If None, a new ID is generated.
    """

    def __init__(
        self,
        num_reads: int,
        inner_channel: Optional[ChannelInterface] = None,
        _channel_id: Optional[str] = None,
    ):
        assert num_reads > 0, "num_reads must be greater than 0."
        self._num_reads = num_reads
        self._inner_channel = inner_channel
        # Generate a unique ID for the channel. The writer and reader will use
        # this ID to store and retrieve data from the _SerializationContext.
        self._channel_id = str(uuid.uuid4()) if _channel_id is None else _channel_id

    def ensure_registered_as_writer(self) -> None:
        """Forward writer registration to the inner channel, if there is one."""
        if self._inner_channel is not None:
            self._inner_channel.ensure_registered_as_writer()

    def ensure_registered_as_reader(self) -> None:
        """Forward reader registration to the inner channel, if there is one."""
        if self._inner_channel is not None:
            self._inner_channel.ensure_registered_as_reader()

    def __reduce__(self):
        """Pickle as a constructor call so the channel ID is preserved."""
        return CachedChannel, (
            self._num_reads,
            self._inner_channel,
            self._channel_id,
        )

    def __str__(self) -> str:
        """Return a one-line, human-readable description of the channel."""
        # All three fields live inside a single, balanced pair of parentheses.
        return (
            f"CachedChannel(channel_id={self._channel_id}, "
            f"num_reads={self._num_reads}, "
            f"inner_channel={self._inner_channel})"
        )

    def write(self, value: Any, timeout: Optional[float] = None):
        """Write `value` to the inner channel, or to the serialization context.

        Args:
            value: The value to publish.
            timeout: Passed through to the inner channel's `write`. Unused when
                there is no inner channel.
        """
        self.ensure_registered_as_writer()
        # TODO: better organize the imports
        from ray.experimental.channel import ChannelContext

        if self._inner_channel is not None:
            self._inner_channel.write(value, timeout)
            return

        # Otherwise no need to check timeout as the operation is non-blocking.

        # Because both the reader and writer are in the same worker process,
        # we can directly store the data in the context instead of storing
        # it in the channel object. This removes the serialization overhead
        # of `value`.
        ctx = ChannelContext.get_current().serialization_context
        ctx.set_data(self._channel_id, value, self._num_reads)

    def read(self, timeout: Optional[float] = None) -> Any:
        """Return the cached value, fetching it from the inner channel if needed.

        If the serialization context already holds data for this channel ID,
        that data is returned. Otherwise the value is read from the inner
        channel, stored in the context together with `num_reads`, and then
        returned through the context.

        Args:
            timeout: Passed through to the inner channel's `read`. Unused when
                the value is already cached.

        Returns:
            The value held for this channel.

        Raises:
            AssertionError: If nothing is cached and there is no inner channel
                to read from.
        """
        self.ensure_registered_as_reader()
        # TODO: better organize the imports
        from ray.experimental.channel import ChannelContext

        ctx = ChannelContext.get_current().serialization_context
        if ctx.has_data(self._channel_id):
            # No need to check timeout as the operation is non-blocking.
            return ctx.get_data(self._channel_id)

        assert (
            self._inner_channel is not None
        ), "Cannot read from the serialization context while inner channel is None."
        value = self._inner_channel.read(timeout)
        ctx.set_data(self._channel_id, value, self._num_reads)
        # NOTE: Currently we make a contract with Compiled Graph users that the
        # channel results should not be mutated by the actor methods.
        # When the user needs to modify the channel results, they should
        # make a copy of the channel results and modify the copy.
        # This is the same contract as used in IntraProcessChannel.
        # This contract is NOT enforced right now in either case.
        # TODO(rui): introduce a flag to control the behavior:
        # for example, by default we make a deep copy of the channel
        # result, but the user can turn off the deep copy for performance
        # improvements.
        # https://github.com/ray-project/ray/issues/47409
        return ctx.get_data(self._channel_id)

    def close(self) -> None:
        """Close the inner channel (if any) and reset this channel's cached data."""
        from ray.experimental.channel import ChannelContext

        if self._inner_channel is not None:
            self._inner_channel.close()
        ctx = ChannelContext.get_current().serialization_context
        ctx.reset_data(self._channel_id)