gcs_pubsub.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274
import asyncio
import logging
import random
from collections import deque
from typing import List, Optional, Tuple

import grpc
from grpc import aio as aiogrpc

import ray._private.gcs_utils as gcs_utils
from ray._common.utils import get_or_create_event_loop
from ray.core.generated import (
    gcs_pb2,
    gcs_service_pb2,
    gcs_service_pb2_grpc,
    pubsub_pb2,
)
# Module-level logger shared by the subscriber classes in this file.
logger = logging.getLogger(__name__)
  17. class _SubscriberBase:
  18. def __init__(self, worker_id: bytes = None):
  19. self._worker_id = worker_id
  20. # self._subscriber_id needs to match the binary format of a random
  21. # SubscriberID / UniqueID, which is 28 (kUniqueIDSize) random bytes.
  22. self._subscriber_id = bytes(bytearray(random.getrandbits(8) for _ in range(28)))
  23. self._last_batch_size = 0
  24. self._max_processed_sequence_id = 0
  25. self._publisher_id = b""
  26. # Batch size of the result from last poll. Used to indicate whether the
  27. # subscriber can keep up.
  28. @property
  29. def last_batch_size(self):
  30. return self._last_batch_size
  31. def _subscribe_request(self, channel):
  32. cmd = pubsub_pb2.Command(channel_type=channel, subscribe_message={})
  33. req = gcs_service_pb2.GcsSubscriberCommandBatchRequest(
  34. subscriber_id=self._subscriber_id, sender_id=self._worker_id, commands=[cmd]
  35. )
  36. return req
  37. def _poll_request(self):
  38. return gcs_service_pb2.GcsSubscriberPollRequest(
  39. subscriber_id=self._subscriber_id,
  40. max_processed_sequence_id=self._max_processed_sequence_id,
  41. publisher_id=self._publisher_id,
  42. )
  43. def _unsubscribe_request(self, channels):
  44. req = gcs_service_pb2.GcsSubscriberCommandBatchRequest(
  45. subscriber_id=self._subscriber_id, sender_id=self._worker_id, commands=[]
  46. )
  47. for channel in channels:
  48. req.commands.append(
  49. pubsub_pb2.Command(channel_type=channel, unsubscribe_message={})
  50. )
  51. return req
  52. @staticmethod
  53. def _should_terminate_polling(e: grpc.RpcError) -> None:
  54. # Caller only expects polling to be terminated after deadline exceeded.
  55. if e.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
  56. return True
  57. # Could be a temporary connection issue. Suppress error.
  58. # TODO: reconnect GRPC channel?
  59. if e.code() == grpc.StatusCode.UNAVAILABLE:
  60. return True
  61. return False
  62. class _AioSubscriber(_SubscriberBase):
  63. """Async io subscriber to GCS.
  64. Usage example common to Aio subscribers:
  65. subscriber = GcsAioXxxSubscriber(address="...")
  66. await subscriber.subscribe()
  67. while running:
  68. ...... = await subscriber.poll()
  69. ......
  70. await subscriber.close()
  71. """
  72. def __init__(
  73. self,
  74. pubsub_channel_type,
  75. worker_id: bytes = None,
  76. address: str = None,
  77. channel: aiogrpc.Channel = None,
  78. ):
  79. super().__init__(worker_id)
  80. if address:
  81. assert channel is None, "address and channel cannot both be specified"
  82. channel = gcs_utils.create_gcs_channel(address, aio=True)
  83. else:
  84. assert channel is not None, "One of address and channel must be specified"
  85. # GRPC stub to GCS pubsub.
  86. self._stub = gcs_service_pb2_grpc.InternalPubSubGcsServiceStub(channel)
  87. # Type of the channel.
  88. self._channel = pubsub_channel_type
  89. # A queue of received PubMessage.
  90. self._queue = deque()
  91. # Indicates whether the subscriber has closed.
  92. self._close = asyncio.Event()
  93. async def subscribe(self) -> None:
  94. """Registers a subscription for the subscriber's channel type.
  95. Before the registration, published messages in the channel will not be
  96. saved for the subscriber.
  97. """
  98. if self._close.is_set():
  99. return
  100. req = self._subscribe_request(self._channel)
  101. await self._stub.GcsSubscriberCommandBatch(req, timeout=30)
  102. async def _poll_call(self, req, timeout=None):
  103. # Wrap GRPC _AioCall as a coroutine.
  104. return await self._stub.GcsSubscriberPoll(req, timeout=timeout)
  105. async def _poll(self, timeout=None) -> None:
  106. while len(self._queue) == 0:
  107. req = self._poll_request()
  108. poll = get_or_create_event_loop().create_task(
  109. self._poll_call(req, timeout=timeout)
  110. )
  111. close = get_or_create_event_loop().create_task(self._close.wait())
  112. done, others = await asyncio.wait(
  113. [poll, close], timeout=timeout, return_when=asyncio.FIRST_COMPLETED
  114. )
  115. # Cancel the other task if needed to prevent memory leak.
  116. other_task = others.pop()
  117. if not other_task.done():
  118. other_task.cancel()
  119. if poll not in done or close in done:
  120. # Request timed out or subscriber closed.
  121. break
  122. try:
  123. self._last_batch_size = len(poll.result().pub_messages)
  124. if poll.result().publisher_id != self._publisher_id:
  125. if self._publisher_id != "":
  126. logger.debug(
  127. f"replied publisher_id {poll.result().publisher_id}"
  128. f"different from {self._publisher_id}, this should "
  129. "only happens during gcs failover."
  130. )
  131. self._publisher_id = poll.result().publisher_id
  132. self._max_processed_sequence_id = 0
  133. for msg in poll.result().pub_messages:
  134. if msg.sequence_id <= self._max_processed_sequence_id:
  135. logger.warning(f"Ignoring out of order message {msg}")
  136. continue
  137. self._max_processed_sequence_id = msg.sequence_id
  138. self._queue.append(msg)
  139. except grpc.RpcError as e:
  140. if self._should_terminate_polling(e):
  141. return
  142. raise
  143. async def close(self) -> None:
  144. """Closes the subscriber and its active subscription."""
  145. # Mark close to terminate inflight polling and prevent future requests.
  146. if self._close.is_set():
  147. return
  148. self._close.set()
  149. req = self._unsubscribe_request(channels=[self._channel])
  150. try:
  151. await self._stub.GcsSubscriberCommandBatch(req, timeout=5)
  152. except Exception:
  153. pass
  154. self._stub = None
  155. class GcsAioResourceUsageSubscriber(_AioSubscriber):
  156. def __init__(
  157. self,
  158. worker_id: bytes = None,
  159. address: str = None,
  160. channel: grpc.Channel = None,
  161. ):
  162. super().__init__(
  163. pubsub_pb2.RAY_NODE_RESOURCE_USAGE_CHANNEL, worker_id, address, channel
  164. )
  165. async def poll(self, timeout=None) -> Tuple[bytes, str]:
  166. """Polls for new resource usage message.
  167. Returns:
  168. A tuple of string reporter ID and resource usage json string.
  169. """
  170. await self._poll(timeout=timeout)
  171. return self._pop_resource_usage(self._queue)
  172. @staticmethod
  173. def _pop_resource_usage(queue):
  174. if len(queue) == 0:
  175. return None, None
  176. msg = queue.popleft()
  177. return msg.key_id.decode(), msg.node_resource_usage_message.json
  178. class GcsAioActorSubscriber(_AioSubscriber):
  179. def __init__(
  180. self,
  181. worker_id: bytes = None,
  182. address: str = None,
  183. channel: grpc.Channel = None,
  184. ):
  185. super().__init__(pubsub_pb2.GCS_ACTOR_CHANNEL, worker_id, address, channel)
  186. @property
  187. def queue_size(self):
  188. return len(self._queue)
  189. async def poll(
  190. self, batch_size, timeout=None
  191. ) -> List[Tuple[bytes, gcs_pb2.ActorTableData]]:
  192. """Polls for new actor message.
  193. Returns:
  194. A list of tuples of binary actor ID and actor table data.
  195. """
  196. await self._poll(timeout=timeout)
  197. return self._pop_actors(self._queue, batch_size=batch_size)
  198. @staticmethod
  199. def _pop_actors(queue, batch_size):
  200. if len(queue) == 0:
  201. return []
  202. popped = 0
  203. msgs = []
  204. while len(queue) > 0 and popped < batch_size:
  205. msg = queue.popleft()
  206. msgs.append((msg.key_id, msg.actor_message))
  207. popped += 1
  208. return msgs
  209. class GcsAioNodeInfoSubscriber(_AioSubscriber):
  210. def __init__(
  211. self,
  212. worker_id: bytes = None,
  213. address: str = None,
  214. channel: grpc.Channel = None,
  215. ):
  216. super().__init__(pubsub_pb2.GCS_NODE_INFO_CHANNEL, worker_id, address, channel)
  217. async def poll(
  218. self, batch_size, timeout=None
  219. ) -> List[Tuple[bytes, gcs_pb2.GcsNodeInfo]]:
  220. """Polls for new node info message.
  221. Returns:
  222. A list of tuples of (node_id, GcsNodeInfo).
  223. """
  224. await self._poll(timeout=timeout)
  225. return self._pop_node_infos(self._queue, batch_size=batch_size)
  226. @staticmethod
  227. def _pop_node_infos(queue, batch_size):
  228. if len(queue) == 0:
  229. return []
  230. popped = 0
  231. msgs = []
  232. while len(queue) > 0 and popped < batch_size:
  233. msg = queue.popleft()
  234. msgs.append((msg.key_id, msg.node_info_message))
  235. popped += 1
  236. return msgs