task_processor.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351
  1. import logging
  2. import threading
  3. import time
  4. from typing import Any, Dict, List, Optional
  5. from celery import Celery
  6. from celery.signals import task_failure, task_unknown
  7. from ray.serve import get_replica_context
  8. from ray.serve._private.constants import (
  9. DEFAULT_CONSUMER_CONCURRENCY,
  10. SERVE_LOGGER_NAME,
  11. )
  12. from ray.serve.schema import (
  13. CeleryAdapterConfig,
  14. TaskProcessorAdapter,
  15. TaskProcessorConfig,
  16. TaskResult,
  17. )
  18. from ray.util.annotations import PublicAPI
  19. logger = logging.getLogger(SERVE_LOGGER_NAME)
  20. CELERY_WORKER_POOL = "worker_pool"
  21. CELERY_WORKER_CONCURRENCY = "worker_concurrency"
  22. CELERY_TASK_IGNORE_RESULT = "task_ignore_result"
  23. CELERY_TASK_ACKS_LATE = "task_acks_late"
  24. CELERY_TASK_REJECT_ON_WORKER_LOST = "task_reject_on_worker_lost"
  25. CELERY_DEFAULT_APP_CONFIG = [
  26. CELERY_WORKER_POOL,
  27. CELERY_WORKER_CONCURRENCY,
  28. CELERY_TASK_IGNORE_RESULT,
  29. CELERY_TASK_ACKS_LATE,
  30. CELERY_TASK_REJECT_ON_WORKER_LOST,
  31. ]
  32. @PublicAPI(stability="alpha")
  33. class CeleryTaskProcessorAdapter(TaskProcessorAdapter):
  34. """
  35. Celery-based task processor adapter.
  36. This adapter does NOT support any async operations.
  37. All operations must be performed synchronously.
  38. """
  39. _app: Celery
  40. _config: TaskProcessorConfig
  41. _worker_thread: Optional[threading.Thread] = None
  42. _worker_hostname: Optional[str] = None
  43. _worker_concurrency: int = DEFAULT_CONSUMER_CONCURRENCY
  44. def __init__(self, config: TaskProcessorConfig, *args, **kwargs):
  45. super().__init__(*args, **kwargs)
  46. if not isinstance(config.adapter_config, CeleryAdapterConfig):
  47. raise TypeError(
  48. "TaskProcessorConfig.adapter_config must be an instance of CeleryAdapterConfig"
  49. )
  50. # Check if any app_custom_config keys conflict with default Celery app config
  51. if config.adapter_config.app_custom_config:
  52. conflicting_keys = set(
  53. config.adapter_config.app_custom_config.keys()
  54. ) & set(CELERY_DEFAULT_APP_CONFIG)
  55. if conflicting_keys:
  56. raise ValueError(
  57. f"The following configuration keys cannot be changed via app_custom_config: {sorted(conflicting_keys)}. "
  58. f"These are managed internally by the CeleryTaskProcessorAdapter."
  59. )
  60. self._config = config
  61. # Celery adapter does not support any async capabilities
  62. # self._async_capabilities is already an empty set from parent class
  63. def initialize(self, consumer_concurrency: int = DEFAULT_CONSUMER_CONCURRENCY):
  64. self._app = Celery(
  65. self._config.queue_name,
  66. backend=self._config.adapter_config.backend_url,
  67. broker=self._config.adapter_config.broker_url,
  68. )
  69. app_configuration = {
  70. CELERY_WORKER_POOL: "threads",
  71. CELERY_WORKER_CONCURRENCY: consumer_concurrency,
  72. CELERY_TASK_IGNORE_RESULT: False, # Store task results so they can be retrieved after completion
  73. CELERY_TASK_ACKS_LATE: True, # Acknowledge tasks only after completion (not when received) for better reliability
  74. CELERY_TASK_REJECT_ON_WORKER_LOST: True, # Reject and requeue tasks when worker is lost to prevent data loss
  75. }
  76. if self._config.adapter_config.app_custom_config:
  77. app_configuration.update(self._config.adapter_config.app_custom_config)
  78. self._app.conf.update(app_configuration)
  79. queue_config = {
  80. self._config.queue_name: {
  81. "exchange": self._config.queue_name,
  82. "exchange_type": "direct",
  83. "routing_key": self._config.queue_name,
  84. },
  85. }
  86. if self._config.failed_task_queue_name:
  87. queue_config[self._config.failed_task_queue_name] = {
  88. "exchange": self._config.failed_task_queue_name,
  89. "exchange_type": "direct",
  90. "routing_key": self._config.failed_task_queue_name,
  91. }
  92. if self._config.unprocessable_task_queue_name:
  93. queue_config[self._config.unprocessable_task_queue_name] = {
  94. "exchange": self._config.unprocessable_task_queue_name,
  95. "exchange_type": "direct",
  96. "routing_key": self._config.unprocessable_task_queue_name,
  97. }
  98. self._app.conf.update(
  99. task_queues=queue_config,
  100. task_routes={
  101. # Default tasks go to main queue
  102. "*": {"queue": self._config.queue_name},
  103. },
  104. )
  105. if self._config.adapter_config.broker_transport_options is not None:
  106. self._app.conf.update(
  107. broker_transport_options=self._config.adapter_config.broker_transport_options,
  108. )
  109. if self._config.failed_task_queue_name:
  110. task_failure.connect(self._handle_task_failure)
  111. if self._config.unprocessable_task_queue_name:
  112. task_unknown.connect(self._handle_unknown_task)
  113. def register_task_handle(self, func, name=None):
  114. task_options = {
  115. "autoretry_for": (Exception,),
  116. "retry_kwargs": {"max_retries": self._config.max_retries},
  117. "retry_backoff": True,
  118. "retry_backoff_max": 60, # Max backoff of 60 seconds
  119. "retry_jitter": False, # Disable jitter for predictable testing
  120. }
  121. if self._config.adapter_config.task_custom_config:
  122. task_options.update(self._config.adapter_config.task_custom_config)
  123. if name:
  124. self._app.task(name=name, **task_options)(func)
  125. else:
  126. self._app.task(**task_options)(func)
  127. def enqueue_task_sync(
  128. self, task_name, args=None, kwargs=None, **options
  129. ) -> TaskResult:
  130. task_response = self._app.send_task(
  131. task_name,
  132. args=args,
  133. kwargs=kwargs,
  134. queue=self._config.queue_name,
  135. **options,
  136. )
  137. return TaskResult(
  138. id=task_response.id,
  139. status=task_response.status,
  140. created_at=time.time(),
  141. result=task_response.result,
  142. )
  143. def get_task_status_sync(self, task_id) -> TaskResult:
  144. task_details = self._app.AsyncResult(task_id)
  145. return TaskResult(
  146. id=task_details.id,
  147. result=task_details.result,
  148. status=task_details.status,
  149. )
  150. def start_consumer(self, **kwargs):
  151. """Starts the Celery worker thread."""
  152. if self._worker_thread is not None and self._worker_thread.is_alive():
  153. logger.info("Celery worker thread is already running.")
  154. return
  155. unique_id = get_replica_context().replica_tag
  156. self._worker_hostname = f"{self._app.main}_{unique_id}"
  157. worker_args = [
  158. "worker",
  159. f"--hostname={self._worker_hostname}",
  160. "-Q",
  161. self._config.queue_name,
  162. ]
  163. self._worker_thread = threading.Thread(
  164. target=self._app.worker_main,
  165. args=(worker_args,),
  166. )
  167. self._worker_thread.start()
  168. logger.info(
  169. f"Celery worker thread started with hostname: {self._worker_hostname}"
  170. )
  171. def stop_consumer(self, timeout: float = 10.0):
  172. """Signals the Celery worker to shut down and waits for it to terminate."""
  173. if self._worker_thread is None or not self._worker_thread.is_alive():
  174. logger.info("Celery worker thread is not running.")
  175. return
  176. logger.info("Sending shutdown signal to Celery worker...")
  177. # Use the worker's hostname for targeted shutdown
  178. self._app.control.broadcast(
  179. "shutdown", destination=[f"celery@{self._worker_hostname}"]
  180. )
  181. self._worker_thread.join(timeout=timeout)
  182. if self._worker_thread.is_alive():
  183. logger.warning(f"Worker thread did not terminate after {timeout} seconds.")
  184. else:
  185. logger.info("Celery worker thread has stopped.")
  186. self._worker_thread = None
  187. def cancel_task_sync(self, task_id):
  188. """
  189. Cancels a task synchronously. Only supported for Redis and RabbitMQ brokers by Celery.
  190. More details can be found here: https://docs.celeryq.dev/en/stable/userguide/workers.html#revoke-revoking-tasks
  191. """
  192. self._app.control.revoke(task_id)
  193. def get_metrics_sync(self) -> Dict[str, Any]:
  194. """
  195. Returns the metrics of the Celery worker synchronously.
  196. More details can be found here: https://docs.celeryq.dev/en/stable/reference/celery.app.control.html#celery.app.control.Inspect.stats
  197. """
  198. return self._app.control.inspect().stats()
  199. def health_check_sync(self) -> List[Dict]:
  200. """
  201. Checks the health of the Celery worker synchronously.
  202. Returns a list of dictionaries, each containing the worker name and a dictionary with the health status.
  203. Example: [{'celery@192.168.1.100': {'ok': 'pong'}}]
  204. More details can be found here: https://docs.celeryq.dev/en/stable/reference/celery.app.control.html#celery.app.control.Control.ping
  205. """
  206. return self._app.control.ping()
  207. def _handle_task_failure(
  208. self,
  209. sender: Any = None,
  210. task_id: str = None,
  211. args: Any = None,
  212. kwargs: Any = None,
  213. einfo: Any = None,
  214. **kw,
  215. ):
  216. """Handle task failures and route them to appropriate dead letter queues.
  217. This method is called when a task fails after all retry attempts have been
  218. exhausted. It logs the failure and moves the task to failed_task_queue
  219. Args:
  220. sender: The task object that failed
  221. task_id: Unique identifier of the failed task
  222. args: Positional arguments passed to the task
  223. kwargs: Keyword arguments passed to the task
  224. einfo: Exception info object containing exception details and traceback
  225. **kw: Additional keyword arguments passed by Celery
  226. """
  227. logger.info(
  228. f"Task failure detected for task_id: {task_id}, einfo: {str(einfo)}"
  229. )
  230. dlq_args = [
  231. task_id,
  232. str(einfo.exception),
  233. str(args),
  234. str(kwargs),
  235. str(einfo),
  236. ]
  237. if self._config.failed_task_queue_name:
  238. self._move_task_to_queue(
  239. self._config.failed_task_queue_name,
  240. sender.name,
  241. dlq_args,
  242. )
  243. logger.error(
  244. f"Task {task_id} failed after max retries. Exception: {einfo}. Moved it to the {self._config.failed_task_queue_name} queue."
  245. )
  246. def _handle_unknown_task(
  247. self,
  248. sender: Any = None,
  249. name: str = None,
  250. id: str = None,
  251. message: Any = None,
  252. exc: Any = None,
  253. **kwargs,
  254. ):
  255. """Handle unknown or unregistered tasks received by Celery.
  256. This method is called when Celery receives a task that it doesn't recognize
  257. (i.e., a task that hasn't been registered with the Celery app). These tasks
  258. are moved to the unprocessable task queue if configured.
  259. Args:
  260. sender: The Celery app or worker that detected the unknown task
  261. name: Name of the unknown task
  262. id: Task ID of the unknown task
  263. message: The raw message received for the unknown task
  264. exc: The exception raised when trying to process the unknown task
  265. **kwargs: Additional context information from Celery
  266. """
  267. logger.info(
  268. f"Unknown task detected by Celery. Name: {name}, ID: {id}, Exc: {str(exc)}"
  269. )
  270. if self._config.unprocessable_task_queue_name:
  271. self._move_task_to_queue(
  272. self._config.unprocessable_task_queue_name,
  273. name,
  274. [
  275. name,
  276. id,
  277. str(message),
  278. str(exc),
  279. str(kwargs),
  280. ],
  281. )
  282. def _move_task_to_queue(self, queue_name: str, task_name: str, args: list):
  283. """Helper function to move a task to a specified queue."""
  284. try:
  285. logger.info(
  286. f"Moving task: {task_name} to queue: {queue_name}, args: {args}"
  287. )
  288. self._app.send_task(
  289. name=task_name,
  290. queue=queue_name,
  291. args=args,
  292. )
  293. except Exception as e:
  294. logger.error(
  295. f"Failed to move task: {task_name} to queue: {queue_name}, error: {e}"
  296. )
  297. raise e