| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111 |
- import uuid
- from typing import Any, Optional
- from ray.experimental.channel.common import ChannelInterface
class CachedChannel(ChannelInterface):
    """
    CachedChannel wraps an inner channel and caches the data read from it until
    `num_reads` reads have completed. If inner channel is None, the data
    is written to serialization context and retrieved from there. This is useful
    when passing data within the same actor and a shared memory channel can be
    avoided.

    Args:
        num_reads: The number of reads from this channel that must happen before
            writing again. Readers must be methods of the same actor.
        inner_channel: The inner channel to cache data from. If None, the data is
            read from the serialization context.
        _channel_id: The unique ID for the channel. If None, a new ID is generated.
    """

    def __init__(
        self,
        num_reads: int,
        inner_channel: Optional[ChannelInterface] = None,
        _channel_id: Optional[str] = None,
    ):
        """Store the read count and inner channel, and pick the channel ID.

        Raises:
            AssertionError: If `num_reads` is not greater than 0.
        """
        assert num_reads > 0, "num_reads must be greater than 0."
        self._num_reads = num_reads
        self._inner_channel = inner_channel
        # Generate a unique ID for the channel. The writer and reader will use
        # this ID to store and retrieve data from the _SerializationContext.
        self._channel_id = _channel_id
        if self._channel_id is None:
            self._channel_id = str(uuid.uuid4())

    def ensure_registered_as_writer(self) -> None:
        """Forward writer registration to the inner channel, if there is one."""
        if self._inner_channel is not None:
            self._inner_channel.ensure_registered_as_writer()

    def ensure_registered_as_reader(self) -> None:
        """Forward reader registration to the inner channel, if there is one."""
        if self._inner_channel is not None:
            self._inner_channel.ensure_registered_as_reader()

    def __reduce__(self):
        """Pickle as a constructor call that reuses the same channel ID.

        Passing `_channel_id` back to the constructor means the reconstructed
        channel does not generate a fresh ID.
        """
        return CachedChannel, (
            self._num_reads,
            self._inner_channel,
            self._channel_id,
        )

    def __str__(self) -> str:
        """Return a one-line description of the channel and its inner channel."""
        # All three fields sit inside a single pair of parentheses; the
        # closing parenthesis belongs after `inner_channel`, not `num_reads`.
        return (
            f"CachedChannel(channel_id={self._channel_id}, "
            f"num_reads={self._num_reads}, "
            f"inner_channel={self._inner_channel})"
        )

    def write(self, value: Any, timeout: Optional[float] = None):
        """Write `value` to the inner channel or to the serialization context.

        Args:
            value: The value to publish to readers.
            timeout: Passed to the inner channel's `write`. Ignored when there
                is no inner channel.
        """
        self.ensure_registered_as_writer()
        # TODO: better organize the imports
        from ray.experimental.channel import ChannelContext

        if self._inner_channel is not None:
            self._inner_channel.write(value, timeout)
            return

        # Otherwise no need to check timeout as the operation is non-blocking.

        # Because both the reader and writer are in the same worker process,
        # we can directly store the data in the context instead of storing
        # it in the channel object. This removes the serialization overhead of `value`.
        ctx = ChannelContext.get_current().serialization_context
        ctx.set_data(self._channel_id, value, self._num_reads)

    def read(self, timeout: Optional[float] = None) -> Any:
        """Return the cached value, reading from the inner channel on a miss.

        Args:
            timeout: Passed to the inner channel's `read` when the value is not
                already in the serialization context.

        Returns:
            The value stored in the serialization context for this channel ID.

        Raises:
            AssertionError: If nothing is cached and there is no inner channel.
        """
        self.ensure_registered_as_reader()
        # TODO: better organize the imports
        from ray.experimental.channel import ChannelContext

        ctx = ChannelContext.get_current().serialization_context
        if ctx.has_data(self._channel_id):
            # No need to check timeout as the operation is non-blocking.
            return ctx.get_data(self._channel_id)

        assert (
            self._inner_channel is not None
        ), "Cannot read from the serialization context while inner channel is None."
        # Cache miss: pull the value from the inner channel and store it in the
        # context, registered with `num_reads`, so later reads hit the cache.
        value = self._inner_channel.read(timeout)
        ctx.set_data(self._channel_id, value, self._num_reads)
        # NOTE: Currently we make a contract with Compiled Graph users that the
        # channel results should not be mutated by the actor methods.
        # When the user needs to modify the channel results, they should
        # make a copy of the channel results and modify the copy.
        # This is the same contract as used in IntraProcessChannel.
        # This contract is NOT enforced right now in either case.
        # TODO(rui): introduce a flag to control the behavior:
        # for example, by default we make a deep copy of the channel
        # result, but the user can turn off the deep copy for performance
        # improvements.
        # https://github.com/ray-project/ray/issues/47409
        return ctx.get_data(self._channel_id)

    def close(self) -> None:
        """Close the inner channel, if any, and reset this channel's cached data."""
        from ray.experimental.channel import ChannelContext

        if self._inner_channel is not None:
            self._inner_channel.close()
        ctx = ChannelContext.get_current().serialization_context
        ctx.reset_data(self._channel_id)
|