| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274 |
- import asyncio
- import logging
- import random
- from collections import deque
- from typing import List, Tuple
- import grpc
- from grpc import aio as aiogrpc
- import ray._private.gcs_utils as gcs_utils
- from ray._common.utils import get_or_create_event_loop
- from ray.core.generated import (
- gcs_pb2,
- gcs_service_pb2,
- gcs_service_pb2_grpc,
- pubsub_pb2,
- )
logger = logging.getLogger(__name__)


class _SubscriberBase:
    """Shared state and request builders for GCS pubsub subscribers.

    Holds the subscriber identity, the publisher id echoed back on each poll,
    and the highest sequence id processed so far. Subclasses send the requests
    built here to the GCS pubsub service.
    """

    def __init__(self, worker_id: bytes = None):
        # Sent as ``sender_id`` on subscribe/unsubscribe command batches.
        self._worker_id = worker_id
        # self._subscriber_id needs to match the binary format of a random
        # SubscriberID / UniqueID, which is 28 (kUniqueIDSize) random bytes.
        self._subscriber_id = bytes(bytearray(random.getrandbits(8) for _ in range(28)))
        # Batch size of the result from last poll. Used to indicate whether the
        # subscriber can keep up.
        self._last_batch_size = 0
        # Highest message sequence id processed so far; reported on every poll
        # request.
        self._max_processed_sequence_id = 0
        # Publisher id reported on every poll request; empty until a poll reply
        # has been processed.
        self._publisher_id = b""

    @property
    def last_batch_size(self):
        """Number of messages in the most recent poll reply."""
        return self._last_batch_size

    def _subscribe_request(self, channel):
        """Builds a command batch that subscribes this subscriber to `channel`."""
        cmd = pubsub_pb2.Command(channel_type=channel, subscribe_message={})
        req = gcs_service_pb2.GcsSubscriberCommandBatchRequest(
            subscriber_id=self._subscriber_id, sender_id=self._worker_id, commands=[cmd]
        )
        return req

    def _poll_request(self):
        """Builds a poll request carrying the current sequence/publisher state."""
        return gcs_service_pb2.GcsSubscriberPollRequest(
            subscriber_id=self._subscriber_id,
            max_processed_sequence_id=self._max_processed_sequence_id,
            publisher_id=self._publisher_id,
        )

    def _unsubscribe_request(self, channels):
        """Builds a command batch that unsubscribes from every channel in `channels`."""
        return gcs_service_pb2.GcsSubscriberCommandBatchRequest(
            subscriber_id=self._subscriber_id,
            sender_id=self._worker_id,
            commands=[
                pubsub_pb2.Command(channel_type=channel, unsubscribe_message={})
                for channel in channels
            ],
        )

    @staticmethod
    def _should_terminate_polling(e: grpc.RpcError) -> bool:
        """Returns whether a poll RPC error should end polling without raising.

        True for DEADLINE_EXCEEDED, because callers expect polling to end after
        the deadline. True for UNAVAILABLE, which could be a temporary connection
        issue and is therefore suppressed. False for any other status code.
        """
        # TODO: reconnect GRPC channel on UNAVAILABLE?
        return e.code() in (
            grpc.StatusCode.DEADLINE_EXCEEDED,
            grpc.StatusCode.UNAVAILABLE,
        )
class _AioSubscriber(_SubscriberBase):
    """Async io subscriber to GCS.

    Usage example common to Aio subscribers:
        subscriber = GcsAioXxxSubscriber(address="...")
        await subscriber.subscribe()
        while running:
            ...... = await subscriber.poll()
            ......
        await subscriber.close()
    """

    def __init__(
        self,
        pubsub_channel_type,
        worker_id: bytes = None,
        address: str = None,
        channel: aiogrpc.Channel = None,
    ):
        """Creates the subscriber.

        Args:
            pubsub_channel_type: pubsub channel type to subscribe to.
            worker_id: optional id sent as ``sender_id`` on command batches.
            address: GCS address; a new aio channel is created for it.
            channel: existing aio channel; used only when `address` is not given.
        """
        super().__init__(worker_id)
        if address:
            assert channel is None, "address and channel cannot both be specified"
            channel = gcs_utils.create_gcs_channel(address, aio=True)
        else:
            assert channel is not None, "One of address and channel must be specified"
        # GRPC stub to GCS pubsub.
        self._stub = gcs_service_pb2_grpc.InternalPubSubGcsServiceStub(channel)
        # Type of the channel.
        self._channel = pubsub_channel_type
        # A queue of received PubMessage.
        self._queue = deque()
        # Indicates whether the subscriber has closed.
        self._close = asyncio.Event()

    async def subscribe(self) -> None:
        """Registers a subscription for the subscriber's channel type.

        Before the registration, published messages in the channel will not be
        saved for the subscriber. Does nothing once the subscriber is closed.
        """
        if self._close.is_set():
            return
        req = self._subscribe_request(self._channel)
        await self._stub.GcsSubscriberCommandBatch(req, timeout=30)

    async def _poll_call(self, req, timeout=None):
        # Wrap GRPC _AioCall as a coroutine.
        return await self._stub.GcsSubscriberPoll(req, timeout=timeout)

    async def _poll(self, timeout=None) -> None:
        """Polls GCS until at least one message is queued, or gives up.

        Returns without queuing anything when the subscriber is closed, when the
        wait times out, or when the poll RPC fails with a status accepted by
        ``_should_terminate_polling``. Other ``grpc.RpcError``s are re-raised.
        """
        loop = get_or_create_event_loop()
        while len(self._queue) == 0:
            # close() releases the stub, so never issue a request after it.
            if self._close.is_set():
                break
            req = self._poll_request()
            poll = loop.create_task(self._poll_call(req, timeout=timeout))
            close = loop.create_task(self._close.wait())
            done, pending = await asyncio.wait(
                [poll, close], timeout=timeout, return_when=asyncio.FIRST_COMPLETED
            )
            # Cancel every unfinished task to prevent leaks. On timeout both tasks
            # are pending; a leftover close.wait() task would otherwise survive
            # until close() is called.
            for task in pending:
                task.cancel()
            if poll not in done or close in done:
                # Request timed out or subscriber closed.
                break
            try:
                reply = poll.result()
            except grpc.RpcError as e:
                if self._should_terminate_polling(e):
                    return
                raise
            self._handle_poll_reply(reply)

    def _handle_poll_reply(self, reply) -> None:
        """Queues new messages from a poll reply and updates sequence bookkeeping."""
        self._last_batch_size = len(reply.pub_messages)
        if reply.publisher_id != self._publisher_id:
            # `_publisher_id` is bytes and empty before the first reply, so only
            # a change from a previously recorded publisher is noteworthy.
            if self._publisher_id:
                logger.debug(
                    "replied publisher_id %s different from %s, this should "
                    "only happen during gcs failover.",
                    reply.publisher_id,
                    self._publisher_id,
                )
            self._publisher_id = reply.publisher_id
            self._max_processed_sequence_id = 0
        for msg in reply.pub_messages:
            if msg.sequence_id <= self._max_processed_sequence_id:
                logger.warning("Ignoring out of order message %s", msg)
                continue
            self._max_processed_sequence_id = msg.sequence_id
            self._queue.append(msg)

    async def close(self) -> None:
        """Closes the subscriber and its active subscription."""
        # Mark close to terminate inflight polling and prevent future requests.
        if self._close.is_set():
            return
        self._close.set()
        req = self._unsubscribe_request(channels=[self._channel])
        try:
            await self._stub.GcsSubscriberCommandBatch(req, timeout=5)
        except Exception:
            # Unsubscribing is best-effort: the subscriber shuts down either way,
            # so record the failure instead of propagating it.
            logger.debug("Failed to unsubscribe from GCS pubsub.", exc_info=True)
        self._stub = None
class GcsAioResourceUsageSubscriber(_AioSubscriber):
    """Async subscriber to the GCS node resource usage channel."""

    def __init__(
        self,
        worker_id: bytes = None,
        address: str = None,
        channel: aiogrpc.Channel = None,
    ):
        super().__init__(
            pubsub_pb2.RAY_NODE_RESOURCE_USAGE_CHANNEL, worker_id, address, channel
        )

    async def poll(self, timeout=None) -> Tuple[str, str]:
        """Polls for new resource usage message.

        Args:
            timeout: forwarded to the underlying poll; None means no timeout.

        Returns:
            A tuple of string reporter ID and resource usage json string, or
            ``(None, None)`` when no message was received.
        """
        await self._poll(timeout=timeout)
        return self._pop_resource_usage(self._queue)

    @staticmethod
    def _pop_resource_usage(queue):
        """Pops one message from `queue` as (decoded key id, usage json)."""
        if not queue:
            return None, None
        msg = queue.popleft()
        return msg.key_id.decode(), msg.node_resource_usage_message.json
class GcsAioActorSubscriber(_AioSubscriber):
    """Async subscriber to the GCS actor channel."""

    def __init__(
        self,
        worker_id: bytes = None,
        address: str = None,
        channel: aiogrpc.Channel = None,
    ):
        super().__init__(pubsub_pb2.GCS_ACTOR_CHANNEL, worker_id, address, channel)

    @property
    def queue_size(self):
        """Number of received messages not yet returned by poll()."""
        return len(self._queue)

    async def poll(
        self, batch_size, timeout=None
    ) -> List[Tuple[bytes, gcs_pb2.ActorTableData]]:
        """Polls for new actor message.

        Args:
            batch_size: maximum number of messages to return.
            timeout: forwarded to the underlying poll; None means no timeout.

        Returns:
            A list of tuples of binary actor ID and actor table data. Empty when
            no message was received.
        """
        await self._poll(timeout=timeout)
        return self._pop_actors(self._queue, batch_size=batch_size)

    @staticmethod
    def _pop_actors(queue, batch_size):
        """Pops up to `batch_size` messages as (key id, actor message) pairs."""
        msgs = []
        # The emptiness test comes first, so an empty queue never compares batch_size.
        while queue and len(msgs) < batch_size:
            msg = queue.popleft()
            msgs.append((msg.key_id, msg.actor_message))
        return msgs
class GcsAioNodeInfoSubscriber(_AioSubscriber):
    """Async subscriber to the GCS node info channel."""

    def __init__(
        self,
        worker_id: bytes = None,
        address: str = None,
        channel: aiogrpc.Channel = None,
    ):
        super().__init__(pubsub_pb2.GCS_NODE_INFO_CHANNEL, worker_id, address, channel)

    async def poll(
        self, batch_size, timeout=None
    ) -> List[Tuple[bytes, gcs_pb2.GcsNodeInfo]]:
        """Polls for new node info message.

        Args:
            batch_size: maximum number of messages to return.
            timeout: forwarded to the underlying poll; None means no timeout.

        Returns:
            A list of tuples of (node_id, GcsNodeInfo). Empty when no message was
            received.
        """
        await self._poll(timeout=timeout)
        return self._pop_node_infos(self._queue, batch_size=batch_size)

    @staticmethod
    def _pop_node_infos(queue, batch_size):
        """Pops up to `batch_size` messages as (key id, node info message) pairs."""
        msgs = []
        # The emptiness test comes first, so an empty queue never compares batch_size.
        while queue and len(msgs) < batch_size:
            msg = queue.popleft()
            msgs.append((msg.key_id, msg.node_info_message))
        return msgs
|