broker.py 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285
  1. # This module provides broker clients for querying queue lengths from message brokers.
  2. # Adapted from Flower's broker.py (https://github.com/mher/flower/blob/master/flower/utils/broker.py)
  3. # with the following modification:
  4. # - Added close() method to BrokerBase and RedisBase for resource cleanup
  5. import json
  6. import logging
  7. import numbers
  8. import socket
  9. from urllib.parse import quote, unquote, urljoin, urlparse
  10. from tornado import httpclient, ioloop
  11. from ray.serve._private.constants import SERVE_LOGGER_NAME
  12. try:
  13. import redis
  14. except ImportError:
  15. redis = None
# Module logger, obtained under Ray Serve's logger name (SERVE_LOGGER_NAME)
# rather than this module's __name__.
logger = logging.getLogger(SERVE_LOGGER_NAME)
  17. class BrokerBase:
  18. def __init__(self, broker_url, *_, **__):
  19. purl = urlparse(broker_url)
  20. self.host = purl.hostname
  21. self.port = purl.port
  22. self.vhost = purl.path[1:]
  23. username = purl.username
  24. password = purl.password
  25. self.username = unquote(username) if username else username
  26. self.password = unquote(password) if password else password
  27. async def queues(self, names):
  28. raise NotImplementedError
  29. def close(self):
  30. """Close any open connections. Override in subclasses as needed."""
  31. pass
  32. class RabbitMQ(BrokerBase):
  33. def __init__(self, broker_url, http_api, io_loop=None, **__):
  34. super().__init__(broker_url)
  35. self.io_loop = io_loop or ioloop.IOLoop.instance()
  36. self.host = self.host or "localhost"
  37. self.port = self.port or 15672
  38. self.vhost = quote(self.vhost, "") or "/" if self.vhost != "/" else self.vhost
  39. self.username = self.username or "guest"
  40. self.password = self.password or "guest"
  41. if not http_api:
  42. http_api = f"http://{self.username}:{self.password}@{self.host}:{self.port}/api/{self.vhost}"
  43. try:
  44. self.validate_http_api(http_api)
  45. except ValueError:
  46. logger.error("Invalid broker api url: %s", http_api)
  47. self.http_api = http_api
  48. async def queues(self, names):
  49. url = urljoin(self.http_api, "queues/" + self.vhost)
  50. api_url = urlparse(self.http_api)
  51. username = unquote(api_url.username or "") or self.username
  52. password = unquote(api_url.password or "") or self.password
  53. http_client = httpclient.AsyncHTTPClient()
  54. try:
  55. response = await http_client.fetch(
  56. url,
  57. auth_username=username,
  58. auth_password=password,
  59. connect_timeout=1.0,
  60. request_timeout=2.0,
  61. validate_cert=False,
  62. )
  63. except (socket.error, httpclient.HTTPError) as e:
  64. logger.error("RabbitMQ management API call failed: %s", e)
  65. return []
  66. finally:
  67. http_client.close()
  68. if response.code == 200:
  69. info = json.loads(response.body.decode())
  70. return [x for x in info if x["name"] in names]
  71. response.rethrow()
  72. @classmethod
  73. def validate_http_api(cls, http_api):
  74. url = urlparse(http_api)
  75. if url.scheme not in ("http", "https"):
  76. raise ValueError(f"Invalid http api schema: {url.scheme}")
  77. class RedisBase(BrokerBase):
  78. DEFAULT_SEP = "\x06\x16"
  79. DEFAULT_PRIORITY_STEPS = [0, 3, 6, 9]
  80. def __init__(self, broker_url, *_, **kwargs):
  81. super().__init__(broker_url)
  82. self.redis = None
  83. if not redis:
  84. raise ImportError("redis library is required")
  85. broker_options = kwargs.get("broker_options", {})
  86. self.priority_steps = broker_options.get(
  87. "priority_steps", self.DEFAULT_PRIORITY_STEPS
  88. )
  89. self.sep = broker_options.get("sep", self.DEFAULT_SEP)
  90. self.broker_prefix = broker_options.get("global_keyprefix", "")
  91. def _q_for_pri(self, queue, pri):
  92. if pri not in self.priority_steps:
  93. raise ValueError("Priority not in priority steps")
  94. # pylint: disable=consider-using-f-string
  95. return "{0}{1}{2}".format(*((queue, self.sep, pri) if pri else (queue, "", "")))
  96. async def queues(self, names):
  97. queue_stats = []
  98. for name in names:
  99. priority_names = [
  100. self.broker_prefix + self._q_for_pri(name, pri)
  101. for pri in self.priority_steps
  102. ]
  103. queue_stats.append(
  104. {
  105. "name": name,
  106. "messages": sum((self.redis.llen(x) for x in priority_names)),
  107. }
  108. )
  109. return queue_stats
  110. def close(self):
  111. """Close the Redis connection."""
  112. if self.redis is not None:
  113. self.redis.close()
  114. self.redis = None
  115. class Redis(RedisBase):
  116. def __init__(self, broker_url, *args, **kwargs):
  117. super().__init__(broker_url, *args, **kwargs)
  118. self.host = self.host or "localhost"
  119. self.port = self.port or 6379
  120. self.vhost = self._prepare_virtual_host(self.vhost)
  121. self.redis = self._get_redis_client()
  122. def _prepare_virtual_host(self, vhost):
  123. if not isinstance(vhost, numbers.Integral):
  124. if not vhost or vhost == "/":
  125. vhost = 0
  126. elif vhost.startswith("/"):
  127. vhost = vhost[1:]
  128. try:
  129. vhost = int(vhost)
  130. except ValueError as exc:
  131. raise ValueError(
  132. f"Database is int between 0 and limit - 1, not {vhost}"
  133. ) from exc
  134. return vhost
  135. def _get_redis_client_args(self):
  136. return {
  137. "host": self.host,
  138. "port": self.port,
  139. "db": self.vhost,
  140. "username": self.username,
  141. "password": self.password,
  142. }
  143. def _get_redis_client(self):
  144. return redis.Redis(**self._get_redis_client_args())
  145. class RedisSentinel(RedisBase):
  146. def __init__(self, broker_url, *args, **kwargs):
  147. super().__init__(broker_url, *args, **kwargs)
  148. broker_options = kwargs.get("broker_options", {})
  149. broker_use_ssl = kwargs.get("broker_use_ssl", None)
  150. self.host = self.host or "localhost"
  151. self.port = self.port or 26379
  152. self.vhost = self._prepare_virtual_host(self.vhost)
  153. self.master_name = self._prepare_master_name(broker_options)
  154. self.redis = self._get_redis_client(broker_options, broker_use_ssl)
  155. def _prepare_virtual_host(self, vhost):
  156. if not isinstance(vhost, numbers.Integral):
  157. if not vhost or vhost == "/":
  158. vhost = 0
  159. elif vhost.startswith("/"):
  160. vhost = vhost[1:]
  161. try:
  162. vhost = int(vhost)
  163. except ValueError as exc:
  164. raise ValueError(
  165. f"Database is int between 0 and limit - 1, not {vhost}"
  166. ) from exc
  167. return vhost
  168. def _prepare_master_name(self, broker_options):
  169. try:
  170. master_name = broker_options["master_name"]
  171. except KeyError as exc:
  172. raise ValueError("master_name is required for Sentinel broker") from exc
  173. return master_name
  174. def _get_redis_client(self, broker_options, broker_use_ssl):
  175. connection_kwargs = {
  176. "password": self.password,
  177. "sentinel_kwargs": broker_options.get("sentinel_kwargs"),
  178. }
  179. if isinstance(broker_use_ssl, dict):
  180. connection_kwargs["ssl"] = True
  181. connection_kwargs.update(broker_use_ssl)
  182. # get all sentinel hosts from Celery App config and use them to initialize Sentinel
  183. sentinel = redis.sentinel.Sentinel(
  184. [(self.host, self.port)], **connection_kwargs
  185. )
  186. redis_client = sentinel.master_for(self.master_name)
  187. return redis_client
  188. class RedisSocket(RedisBase):
  189. def __init__(self, broker_url, *args, **kwargs):
  190. super().__init__(broker_url, *args, **kwargs)
  191. self.redis = redis.Redis(
  192. unix_socket_path="/" + self.vhost, password=self.password
  193. )
  194. class RedisSsl(Redis):
  195. """
  196. Redis SSL class offering connection to the broker over SSL.
  197. This does not currently support SSL settings through the url, only through
  198. the broker_use_ssl celery configuration.
  199. """
  200. def __init__(self, broker_url, *args, **kwargs):
  201. if "broker_use_ssl" not in kwargs:
  202. raise ValueError("rediss broker requires broker_use_ssl")
  203. self.broker_use_ssl = kwargs.get("broker_use_ssl", {})
  204. super().__init__(broker_url, *args, **kwargs)
  205. def _get_redis_client_args(self):
  206. client_args = super()._get_redis_client_args()
  207. client_args["ssl"] = True
  208. if isinstance(self.broker_use_ssl, dict):
  209. client_args.update(self.broker_use_ssl)
  210. return client_args
  211. class Broker:
  212. """Factory returning the appropriate broker client based on URL scheme.
  213. Supported schemes:
  214. ``amqp`` or ``amqps`` -> :class:`RabbitMQ`
  215. ``redis`` -> :class:`Redis`
  216. ``rediss`` -> :class:`RedisSsl`
  217. ``redis+socket`` -> :class:`RedisSocket`
  218. ``sentinel`` -> :class:`RedisSentinel`
  219. """
  220. def __new__(cls, broker_url, *args, **kwargs):
  221. scheme = urlparse(broker_url).scheme
  222. if scheme in ("amqp", "amqps"):
  223. return RabbitMQ(broker_url, *args, **kwargs)
  224. if scheme == "redis":
  225. return Redis(broker_url, *args, **kwargs)
  226. if scheme == "rediss":
  227. return RedisSsl(broker_url, *args, **kwargs)
  228. if scheme == "redis+socket":
  229. return RedisSocket(broker_url, *args, **kwargs)
  230. if scheme == "sentinel":
  231. return RedisSentinel(broker_url, *args, **kwargs)
  232. raise NotImplementedError
  233. async def queues(self, names):
  234. raise NotImplementedError