# llm.py -- public configuration classes and builder functions exposed as ``ray.data.llm``.
  1. import logging
  2. from typing import Any, Dict, Optional
  3. from ray._common.deprecation import Deprecated
  4. from ray.data.block import UserDefinedFunction
  5. from ray.llm._internal.batch.processor import (
  6. HttpRequestProcessorConfig as _HttpRequestProcessorConfig,
  7. Processor,
  8. ProcessorConfig as _ProcessorConfig,
  9. ServeDeploymentProcessorConfig as _ServeDeploymentProcessorConfig,
  10. SGLangEngineProcessorConfig as _SGLangEngineProcessorConfig,
  11. vLLMEngineProcessorConfig as _vLLMEngineProcessorConfig,
  12. )
  13. from ray.llm._internal.batch.stages.configs import (
  14. ChatTemplateStageConfig as _ChatTemplateStageConfig,
  15. DetokenizeStageConfig as _DetokenizeStageConfig,
  16. HttpRequestStageConfig as _HttpRequestStageConfig,
  17. PrepareImageStageConfig as _PrepareImageStageConfig,
  18. PrepareMultimodalStageConfig as _PrepareMultimodalStageConfig,
  19. TokenizerStageConfig as _TokenizerStageConfig,
  20. )
  21. from ray.util.annotations import PublicAPI
  22. logger = logging.getLogger(__name__)
  23. @PublicAPI(stability="alpha")
  24. class ProcessorConfig(_ProcessorConfig):
  25. """The processor configuration.
  26. Args:
  27. batch_size: Configures batch size for the processor. Large batch sizes are
  28. likely to saturate the compute resources and could achieve higher throughput.
  29. On the other hand, small batch sizes are more fault-tolerant and could
  30. reduce bubbles in the data pipeline. You can tune the batch size to balance
  31. the throughput and fault-tolerance based on your use case.
  32. resources_per_bundle: The resource bundles for placement groups.
  33. You can specify a custom device label e.g. {'NPU': 1}.
  34. The default resource bundle for LLM Stage is always a GPU resource i.e. {'GPU': 1}.
  35. accelerator_type: The accelerator type used by the LLM stage in a processor.
  36. Default to None, meaning that only the CPU will be used.
  37. concurrency: The number of workers for data parallelism. Default to 1.
  38. If ``concurrency`` is a ``tuple`` ``(m, n)``, Ray creates an autoscaling
  39. actor pool that scales between ``m`` and ``n`` workers (``1 <= m <= n``).
  40. If ``concurrency`` is an ``int`` ``n``, Ray uses either a fixed pool of ``n``
  41. workers or an autoscaling pool from ``1`` to ``n`` workers, depending on
  42. the processor and stage.
  43. """
  44. pass
  45. @PublicAPI(stability="alpha")
  46. class HttpRequestProcessorConfig(_HttpRequestProcessorConfig):
  47. """The configuration for the HTTP request processor.
  48. Args:
  49. batch_size: The batch size to send to the HTTP request.
  50. url: The URL to send the HTTP request to.
  51. headers: The headers to send with the HTTP request.
  52. concurrency: The number of concurrent requests to send. Default to 1.
  53. If ``concurrency`` is an ``int`` ``n``, a fixed pool of ``n`` workers is used.
  54. If ``concurrency`` is a ``tuple`` ``(m, n)``, autoscaling strategy
  55. is used (``1 <= m <= n``).
  56. Examples:
  57. .. testcode::
  58. :skipif: True
  59. import ray
  60. from ray.data.llm import HttpRequestProcessorConfig, build_processor
  61. config = HttpRequestProcessorConfig(
  62. url="https://api.openai.com/v1/chat/completions",
  63. headers={"Authorization": "Bearer sk-..."},
  64. concurrency=1,
  65. )
  66. processor = build_processor(
  67. config,
  68. preprocess=lambda row: dict(
  69. payload=dict(
  70. model="gpt-4o-mini",
  71. messages=[
  72. {"role": "system", "content": "You are a calculator"},
  73. {"role": "user", "content": f"{row['id']} ** 3 = ?"},
  74. ],
  75. temperature=0.3,
  76. max_tokens=20,
  77. ),
  78. ),
  79. postprocess=lambda row: dict(
  80. resp=row["http_response"]["choices"][0]["message"]["content"],
  81. ),
  82. )
  83. ds = ray.data.range(10)
  84. ds = processor(ds)
  85. for row in ds.take_all():
  86. print(row)
  87. """
  88. pass
  89. @PublicAPI(stability="alpha")
  90. class vLLMEngineProcessorConfig(_vLLMEngineProcessorConfig):
  91. """The configuration for the vLLM engine processor.
  92. Args:
  93. model_source: The model source to use for the vLLM engine.
  94. batch_size: The batch size to send to the vLLM engine. Large batch sizes are
  95. likely to saturate the compute resources and could achieve higher throughput.
  96. On the other hand, small batch sizes are more fault-tolerant and could
  97. reduce bubbles in the data pipeline. You can tune the batch size to balance
  98. the throughput and fault-tolerance based on your use case.
  99. engine_kwargs: The kwargs to pass to the vLLM engine. Default engine kwargs are
  100. pipeline_parallel_size: 1, tensor_parallel_size: 1, max_num_seqs: 128,
  101. distributed_executor_backend: "mp".
  102. task_type: The task type to use. If not specified, will use 'generate' by default.
  103. runtime_env: The runtime environment to use for the vLLM engine. See
  104. :ref:`this doc <handling_dependencies>` for more details.
  105. max_pending_requests: The maximum number of pending requests. If not specified,
  106. will use the default value from the vLLM engine.
  107. max_concurrent_batches: The maximum number of concurrent batches in the engine.
  108. This is to overlap the batch processing to avoid the tail latency of
  109. each batch. The default value may not be optimal when the batch size
  110. or the batch processing latency is too small, but it should be good
  111. enough for batch size >= 64.
  112. should_continue_on_error: If True, continue processing when inference fails for a row
  113. instead of raising an exception. Failed rows will have a non-empty
  114. ``__inference_error__`` column containing the error message, and other
  115. output columns will be empty strings. Error rows bypass postprocess. If False
  116. (default), any inference error will raise an exception.
  117. chat_template_stage: Chat templating stage config (bool | dict | ChatTemplateStageConfig).
  118. Defaults to True. Use nested config for per-stage control over batch_size,
  119. concurrency, runtime_env, num_cpus, and memory. Legacy ``apply_chat_template``
  120. and ``chat_template`` fields are deprecated but still supported.
  121. tokenize_stage: Tokenizer stage config (bool | dict | TokenizerStageConfig).
  122. Defaults to True. Use nested config for per-stage control over batch_size,
  123. concurrency, runtime_env, num_cpus, memory, and model_source. Legacy
  124. ``tokenize`` field is deprecated but still supported.
  125. detokenize_stage: Detokenizer stage config (bool | dict | DetokenizeStageConfig).
  126. Defaults to True. Use nested config for per-stage control over batch_size,
  127. concurrency, runtime_env, num_cpus, memory, and model_source. Legacy
  128. ``detokenize`` field is deprecated but still supported.
  129. prepare_image_stage: Prepare image stage config (bool | dict | PrepareImageStageConfig).
  130. Defaults to False. Use nested config for per-stage control over batch_size,
  131. concurrency, runtime_env, num_cpus, and memory. Both the legacy ``has_image`` field
  132. and ``prepare_image_stage`` are deprecated but still supported. Prefer to use multimodal
  133. processor to process multimodal data instead.
  134. accelerator_type: The accelerator type used by the LLM stage in a processor.
  135. Default to None, meaning that only the CPU will be used.
  136. concurrency: The number of workers for data parallelism. Default to 1.
  137. If ``concurrency`` is a tuple ``(m, n)``, Ray creates an autoscaling
  138. actor pool that scales between ``m`` and ``n`` workers (``1 <= m <= n``).
  139. If ``concurrency`` is an ``int`` ``n``, CPU stages use an autoscaling
  140. pool from ``(1, n)``, while GPU stages use a fixed pool of ``n`` workers.
  141. Stage-specific concurrency can be set via nested stage configs.
  142. Examples:
  143. .. testcode::
  144. :skipif: True
  145. import ray
  146. from ray.data.llm import vLLMEngineProcessorConfig, build_processor
  147. config = vLLMEngineProcessorConfig(
  148. model_source="meta-llama/Meta-Llama-3.1-8B-Instruct",
  149. engine_kwargs=dict(
  150. enable_prefix_caching=True,
  151. enable_chunked_prefill=True,
  152. max_num_batched_tokens=4096,
  153. ),
  154. concurrency=1,
  155. batch_size=64,
  156. )
  157. processor = build_processor(
  158. config,
  159. preprocess=lambda row: dict(
  160. messages=[
  161. {"role": "system", "content": "You are a calculator"},
  162. {"role": "user", "content": f"{row['id']} ** 3 = ?"},
  163. ],
  164. sampling_params=dict(
  165. temperature=0.3,
  166. max_tokens=20,
  167. detokenize=False,
  168. ),
  169. ),
  170. postprocess=lambda row: dict(
  171. resp=row["generated_text"],
  172. ),
  173. )
  174. # The processor requires specific input columns, which depend on
  175. # your processor config. You can use the following API to check
  176. # the required input columns:
  177. processor.log_input_column_names()
  178. # Example log:
  179. # The first stage of the processor is ChatTemplateStage.
  180. # Required input columns:
  181. # messages: A list of messages in OpenAI chat format.
  182. ds = ray.data.range(300)
  183. ds = processor(ds)
  184. for row in ds.take_all():
  185. print(row)
  186. """
  187. pass
  188. @PublicAPI(stability="alpha")
  189. class SGLangEngineProcessorConfig(_SGLangEngineProcessorConfig):
  190. """The configuration for the SGLang engine processor.
  191. Args:
  192. model_source: The model source to use for the SGLang engine.
  193. batch_size: The batch size to send to the SGLang engine. Large batch sizes are
  194. likely to saturate the compute resources and could achieve higher throughput.
  195. On the other hand, small batch sizes are more fault-tolerant and could
  196. reduce bubbles in the data pipeline. You can tune the batch size to balance
  197. the throughput and fault-tolerance based on your use case.
  198. engine_kwargs: The kwargs to pass to the SGLang engine. Default engine kwargs are
  199. tp_size: 1, dp_size: 1, skip_tokenizer_init: True.
  200. task_type: The task type to use. If not specified, will use 'generate' by default.
  201. runtime_env: The runtime environment to use for the SGLang engine. See
  202. :ref:`this doc <handling_dependencies>` for more details.
  203. max_pending_requests: The maximum number of pending requests. If not specified,
  204. will use the default value from the SGLang engine.
  205. max_concurrent_batches: The maximum number of concurrent batches in the engine.
  206. This is to overlap the batch processing to avoid the tail latency of
  207. each batch. The default value may not be optimal when the batch size
  208. or the batch processing latency is too small, but it should be good
  209. enough for batch size >= 64.
  210. chat_template_stage: Chat templating stage config (bool | dict | ChatTemplateStageConfig).
  211. Defaults to True. Use nested config for per-stage control over batch_size,
  212. concurrency, runtime_env, num_cpus, and memory. Legacy ``apply_chat_template``
  213. and ``chat_template`` fields are deprecated but still supported.
  214. tokenize_stage: Tokenizer stage config (bool | dict | TokenizerStageConfig).
  215. Defaults to True. Use nested config for per-stage control over batch_size,
  216. concurrency, runtime_env, num_cpus, memory, and model_source. Legacy
  217. ``tokenize`` field is deprecated but still supported.
  218. detokenize_stage: Detokenizer stage config (bool | dict | DetokenizeStageConfig).
  219. Defaults to True. Use nested config for per-stage control over batch_size,
  220. concurrency, runtime_env, num_cpus, memory, and model_source. Legacy
  221. ``detokenize`` field is deprecated but still supported.
  222. accelerator_type: The accelerator type used by the LLM stage in a processor.
  223. Default to None, meaning that only the CPU will be used.
  224. concurrency: The number of workers for data parallelism. Default to 1.
  225. If ``concurrency`` is a tuple ``(m, n)``, Ray creates an autoscaling
  226. actor pool that scales between ``m`` and ``n`` workers (``1 <= m <= n``).
  227. If ``concurrency`` is an ``int`` ``n``, CPU stages use an autoscaling
  228. pool from ``(1, n)``, while GPU stages use a fixed pool of ``n`` workers.
  229. Stage-specific concurrency can be set via nested stage configs.
  230. Examples:
  231. .. testcode::
  232. :skipif: True
  233. import ray
  234. from ray.data.llm import SGLangEngineProcessorConfig, build_processor
  235. config = SGLangEngineProcessorConfig(
  236. model_source="meta-llama/Meta-Llama-3.1-8B-Instruct",
  237. engine_kwargs=dict(
  238. dtype="half",
  239. ),
  240. concurrency=1,
  241. batch_size=64,
  242. )
  243. processor = build_processor(
  244. config,
  245. preprocess=lambda row: dict(
  246. messages=[
  247. {"role": "system", "content": "You are a calculator"},
  248. {"role": "user", "content": f"{row['id']} ** 3 = ?"},
  249. ],
  250. sampling_params=dict(
  251. temperature=0.3,
  252. max_new_tokens=20,
  253. ),
  254. ),
  255. postprocess=lambda row: dict(
  256. resp=row["generated_text"],
  257. ),
  258. )
  259. ds = ray.data.range(300)
  260. ds = processor(ds)
  261. for row in ds.take_all():
  262. print(row)
  263. """
  264. pass
  265. @PublicAPI(stability="alpha")
  266. class ServeDeploymentProcessorConfig(_ServeDeploymentProcessorConfig):
  267. """The configuration for the serve deployment processor.
  268. This processor enables sharing serve deployments across multiple processors. This is useful
  269. for sharing the same LLM engine across multiple processors.
  270. Args:
  271. deployment_name: The name of the serve deployment to use.
  272. app_name: The name of the serve application to use.
  273. batch_size: The batch size to send to the serve deployment. Large batch sizes are
  274. likely to saturate the compute resources and could achieve higher throughput.
  275. On the other hand, small batch sizes are more fault-tolerant and could
  276. reduce bubbles in the data pipeline. You can tune the batch size to balance
  277. the throughput and fault-tolerance based on your use case.
  278. dtype_mapping: The mapping of the request class name to the request class. If this is
  279. not provided, the serve deployment is expected to accept a dict as the request.
  280. concurrency: The number of workers for data parallelism. Default to 1. Note that this is
  281. not the concurrency of the underlying serve deployment.
  282. If ``concurrency`` is an ``int`` ``n``, a fixed pool of ``n`` workers is used.
  283. If ``concurrency`` is a ``tuple`` ``(m, n)``, autoscaling strategy
  284. is used (``1 <= m <= n``).
  285. Examples:
  286. .. testcode::
  287. :skipif: True
  288. import ray
  289. from ray import serve
  290. from ray.data.llm import ServeDeploymentProcessorConfig, build_processor
  291. from ray.serve.llm import (
  292. LLMConfig,
  293. ModelLoadingConfig,
  294. build_llm_deployment,
  295. )
  296. from ray.serve.llm.openai_api_models import CompletionRequest
  297. llm_config = LLMConfig(
  298. model_loading_config=ModelLoadingConfig(
  299. model_id="facebook/opt-1.3b",
  300. model_source="facebook/opt-1.3b",
  301. ),
  302. accelerator_type="A10G",
  303. deployment_config=dict(
  304. name="facebook",
  305. autoscaling_config=dict(
  306. min_replicas=1,
  307. max_replicas=1,
  308. ),
  309. ),
  310. engine_kwargs=dict(
  311. enable_prefix_caching=True,
  312. enable_chunked_prefill=True,
  313. max_num_batched_tokens=4096,
  314. ),
  315. )
  316. APP_NAME = "facebook_opt_app"
  317. DEPLOYMENT_NAME = "facebook_deployment"
  318. override_serve_options = dict(name=DEPLOYMENT_NAME)
  319. llm_app = build_llm_deployment(
  320. llm_config, override_serve_options=override_serve_options
  321. )
  322. app = serve.run(llm_app, name=APP_NAME)
  323. config = ServeDeploymentProcessorConfig(
  324. deployment_name=DEPLOYMENT_NAME,
  325. app_name=APP_NAME,
  326. dtype_mapping={
  327. "CompletionRequest": CompletionRequest,
  328. },
  329. concurrency=1,
  330. batch_size=64,
  331. )
  332. processor = build_processor(
  333. config,
  334. preprocess=lambda row: dict(
  335. method="completions",
  336. dtype="CompletionRequest",
  337. request_kwargs=dict(
  338. model="facebook/opt-1.3b",
  339. prompt=f"This is a prompt for {row['id']}",
  340. stream=False,
  341. ),
  342. ),
  343. postprocess=lambda row: dict(
  344. resp=row["choices"][0]["text"],
  345. ),
  346. )
  347. # The processor requires specific input columns, which depend on
  348. # your processor config. You can use the following API to check
  349. # the required input columns:
  350. processor.log_input_column_names()
  351. ds = ray.data.range(10)
  352. ds = processor(ds)
  353. for row in ds.take_all():
  354. print(row)
  355. """
  356. pass
  357. @PublicAPI(stability="alpha")
  358. class ChatTemplateStageConfig(_ChatTemplateStageConfig):
  359. """The configuration for the chat template stage.
  360. Args:
  361. enabled: Whether this stage is enabled. Defaults to True.
  362. model_source: Model source/identifier for this stage. If not specified,
  363. will use the processor-level model_source.
  364. chat_template: The chat template in Jinja template format. This is
  365. usually not needed if the model checkpoint already contains the
  366. chat template.
  367. chat_template_kwargs: Optional kwargs to pass to apply_chat_template.
  368. batch_size: Rows per batch. If not specified, will use the processor-level
  369. batch_size.
  370. concurrency: Actor pool size or range for this stage. If not specified,
  371. will use the processor-level concurrency. If ``concurrency`` is a
  372. tuple ``(m, n)``, Ray creates an autoscaling actor pool that scales
  373. between ``m`` and ``n`` workers (``1 <= m <= n``). If ``concurrency``
  374. is an ``int`` ``n``, CPU stages use an autoscaling pool from ``(1, n)``.
  375. runtime_env: Optional runtime environment for this stage. If not specified,
  376. will use the processor-level runtime_env. See
  377. :ref:`this doc <handling_dependencies>` for more details.
  378. num_cpus: Number of CPUs to reserve for each map worker in this stage.
  379. memory: Heap memory in bytes to reserve for each map worker in this stage.
  380. """
  381. pass
  382. @PublicAPI(stability="alpha")
  383. class DetokenizeStageConfig(_DetokenizeStageConfig):
  384. """The configuration for the detokenize stage.
  385. Args:
  386. enabled: Whether this stage is enabled. Defaults to True.
  387. model_source: Model source/identifier for this stage. If not specified,
  388. will use the processor-level model_source.
  389. batch_size: Rows per batch. If not specified, will use the processor-level
  390. batch_size.
  391. concurrency: Actor pool size or range for this stage. If not specified,
  392. will use the processor-level concurrency. If ``concurrency`` is a
  393. tuple ``(m, n)``, Ray creates an autoscaling actor pool that scales
  394. between ``m`` and ``n`` workers (``1 <= m <= n``). If ``concurrency``
  395. is an ``int`` ``n``, CPU stages use an autoscaling pool from ``(1, n)``.
  396. runtime_env: Optional runtime environment for this stage. If not specified,
  397. will use the processor-level runtime_env. See
  398. :ref:`this doc <handling_dependencies>` for more details.
  399. num_cpus: Number of CPUs to reserve for each map worker in this stage.
  400. memory: Heap memory in bytes to reserve for each map worker in this stage.
  401. """
  402. pass
  403. @PublicAPI(stability="alpha")
  404. class PrepareMultimodalStageConfig(_PrepareMultimodalStageConfig):
  405. """The configuration for the prepare multimodal stage.
  406. Args:
  407. enabled: Whether this stage is enabled. Defaults to True.
  408. model_config_kwargs: Optional kwargs to pass to the model config.
  409. See available model config kwargs at
  410. https://docs.vllm.ai/en/latest/api/vllm/config/#vllm.config.ModelConfig.
  411. chat_template_content_format: The content format to use for the chat
  412. template. This is used to format the chat template content according
  413. to a specific model. Choices are "string" or "openai". Defaults to
  414. "string".
  415. apply_sys_msg_formatting: Whether to apply formatting system messages.
  416. Defaults to False.
  417. batch_size: Rows per batch. If not specified, will use the processor-level
  418. batch_size.
  419. concurrency: Actor pool size or range for this stage. If not specified,
  420. will use the processor-level concurrency. If ``concurrency`` is a
  421. tuple ``(m, n)``, Ray creates an autoscaling actor pool that scales
  422. between ``m`` and ``n`` workers (``1 <= m <= n``). If ``concurrency``
  423. is an ``int`` ``n``, CPU stages use an autoscaling pool from ``(1, n)``.
  424. runtime_env: Optional runtime environment for this stage. If not specified,
  425. will use the processor-level runtime_env. See
  426. :ref:`this doc <handling_dependencies>` for more details.
  427. num_cpus: Number of CPUs to reserve for each map worker in this stage.
  428. memory: Heap memory in bytes to reserve for each map worker in this stage.
  429. """
  430. pass
  431. @PublicAPI(stability="alpha")
  432. class TokenizerStageConfig(_TokenizerStageConfig):
  433. """The configuration for the tokenizer stage.
  434. Args:
  435. enabled: Whether this stage is enabled. Defaults to True.
  436. model_source: Model source/identifier for this stage. If not specified,
  437. will use the processor-level model_source.
  438. batch_size: Rows per batch. If not specified, will use the processor-level
  439. batch_size.
  440. concurrency: Actor pool size or range for this stage. If not specified,
  441. will use the processor-level concurrency. If ``concurrency`` is a
  442. tuple ``(m, n)``, Ray creates an autoscaling actor pool that scales
  443. between ``m`` and ``n`` workers (``1 <= m <= n``). If ``concurrency``
  444. is an ``int`` ``n``, CPU stages use an autoscaling pool from ``(1, n)``.
  445. runtime_env: Optional runtime environment for this stage. If not specified,
  446. will use the processor-level runtime_env. See
  447. :ref:`this doc <handling_dependencies>` for more details.
  448. num_cpus: Number of CPUs to reserve for each map worker in this stage.
  449. memory: Heap memory in bytes to reserve for each map worker in this stage.
  450. """
  451. pass
  452. @PublicAPI(stability="alpha")
  453. class HttpRequestStageConfig(_HttpRequestStageConfig):
  454. """The configuration for the http request stage.
  455. Args:
  456. enabled: Whether this stage is enabled. Defaults to True.
  457. batch_size: Rows per batch. If not specified, will use the processor-level
  458. batch_size.
  459. concurrency: Actor pool size or range for this stage. If not specified,
  460. will use the processor-level concurrency. If ``concurrency`` is a
  461. tuple ``(m, n)``, Ray creates an autoscaling actor pool that scales
  462. between ``m`` and ``n`` workers (``1 <= m <= n``). If ``concurrency``
  463. is an ``int`` ``n``, CPU stages use an autoscaling pool from ``(1, n)``.
  464. runtime_env: Optional runtime environment for this stage. If not specified,
  465. will use the processor-level runtime_env. See
  466. :ref:`this doc <handling_dependencies>` for more details.
  467. num_cpus: Number of CPUs to reserve for each map worker in this stage.
  468. memory: Heap memory in bytes to reserve for each map worker in this stage.
  469. """
  470. pass
  471. @PublicAPI(stability="alpha")
  472. class PrepareImageStageConfig(_PrepareImageStageConfig):
  473. """The configuration for the prepare image stage.
  474. Args:
  475. enabled: Whether this stage is enabled. Defaults to True.
  476. batch_size: Rows per batch. If not specified, will use the processor-level
  477. batch_size.
  478. concurrency: Actor pool size or range for this stage. If not specified,
  479. will use the processor-level concurrency. If ``concurrency`` is a
  480. tuple ``(m, n)``, Ray creates an autoscaling actor pool that scales
  481. between ``m`` and ``n`` workers (``1 <= m <= n``). If ``concurrency``
  482. is an ``int`` ``n``, CPU stages use an autoscaling pool from ``(1, n)``.
  483. runtime_env: Optional runtime environment for this stage. If not specified,
  484. will use the processor-level runtime_env. See
  485. :ref:`this doc <handling_dependencies>` for more details.
  486. num_cpus: Number of CPUs to reserve for each map worker in this stage.
  487. memory: Heap memory in bytes to reserve for each map worker in this stage.
  488. """
  489. pass
  490. @Deprecated(new="build_processor", error=False)
  491. def build_llm_processor(
  492. config: ProcessorConfig,
  493. preprocess: Optional[UserDefinedFunction] = None,
  494. postprocess: Optional[UserDefinedFunction] = None,
  495. preprocess_map_kwargs: Optional[Dict[str, Any]] = None,
  496. postprocess_map_kwargs: Optional[Dict[str, Any]] = None,
  497. builder_kwargs: Optional[Dict[str, Any]] = None,
  498. ) -> Processor:
  499. """
  500. [DEPRECATED] Prefer build_processor. Build a LLM processor using the given config.
  501. """
  502. return build_processor(
  503. config,
  504. preprocess,
  505. postprocess,
  506. preprocess_map_kwargs,
  507. postprocess_map_kwargs,
  508. builder_kwargs,
  509. )
  510. @PublicAPI(stability="alpha")
  511. def build_processor(
  512. config: ProcessorConfig,
  513. preprocess: Optional[UserDefinedFunction] = None,
  514. postprocess: Optional[UserDefinedFunction] = None,
  515. preprocess_map_kwargs: Optional[Dict[str, Any]] = None,
  516. postprocess_map_kwargs: Optional[Dict[str, Any]] = None,
  517. builder_kwargs: Optional[Dict[str, Any]] = None,
  518. ) -> Processor:
  519. """Build a processor using the given config.
  520. Args:
  521. config: The processor config. Supports nested stage configs for per-stage
  522. control over batch_size, concurrency, runtime_env, num_cpus, and memory
  523. (e.g., ``chat_template_stage=ChatTemplateStageConfig(batch_size=128)``
  524. or ``tokenize_stage={"batch_size": 256, "concurrency": 2}``). Legacy
  525. boolean flags (``apply_chat_template``, ``tokenize``, ``detokenize``,
  526. ``has_image``) are deprecated but still supported with deprecation warnings.
  527. preprocess: An optional lambda function that takes a row (dict) as input
  528. and returns a preprocessed row (dict). The output row must contain the
  529. required fields for the following processing stages. Each row
  530. can contain a `sampling_params` or `pooling_params` field which will be used
  531. by the engine for row-specific sampling or pooling parameters respectively.
  532. Note that all columns will be carried over until the postprocess stage.
  533. postprocess: An optional lambda function that takes a row (dict) as input
  534. and returns a postprocessed row (dict). To keep all the original columns,
  535. you can use the `**row` syntax to return all the original columns.
  536. preprocess_map_kwargs: Optional kwargs to pass to Dataset.map() for the
  537. preprocess stage. Useful for controlling resources (e.g., num_cpus=0.5)
  538. and concurrency independently of the main LLM stage.
  539. postprocess_map_kwargs: Optional kwargs to pass to Dataset.map() for the
  540. postprocess stage. Useful for controlling resources (e.g., num_cpus=0.25)
  541. and concurrency independently of the main LLM stage.
  542. builder_kwargs: Optional additional kwargs to pass to the processor builder
  543. function. These will be passed through to the registered builder and
  544. should match the signature of the specific builder being used.
  545. For example, vLLM and SGLang processors support `chat_template_kwargs`.
  546. Returns:
  547. The built processor.
  548. Examples:
  549. Basic usage:
  550. .. testcode::
  551. :skipif: True
  552. import ray
  553. from ray.data.llm import vLLMEngineProcessorConfig, build_processor
  554. config = vLLMEngineProcessorConfig(
  555. model_source="meta-llama/Meta-Llama-3.1-8B-Instruct",
  556. engine_kwargs=dict(
  557. enable_prefix_caching=True,
  558. enable_chunked_prefill=True,
  559. max_num_batched_tokens=4096,
  560. ),
  561. concurrency=1,
  562. batch_size=64,
  563. )
  564. processor = build_processor(
  565. config,
  566. preprocess=lambda row: dict(
  567. messages=[
  568. {"role": "system", "content": "You are a calculator"},
  569. {"role": "user", "content": f"{row['id']} ** 3 = ?"},
  570. ],
  571. sampling_params=dict(
  572. temperature=0.3,
  573. max_tokens=20,
  574. detokenize=False,
  575. ),
  576. ),
  577. postprocess=lambda row: dict(
  578. resp=row["generated_text"],
  579. **row, # This will return all the original columns in the dataset.
  580. ),
  581. )
  582. ds = ray.data.range(300)
  583. ds = processor(ds)
  584. for row in ds.take_all():
  585. print(row)
  586. Using map_kwargs to control preprocess/postprocess resources:
  587. .. testcode::
  588. :skipif: True
  589. import ray
  590. from ray.data.llm import vLLMEngineProcessorConfig, build_processor
  591. config = vLLMEngineProcessorConfig(
  592. model_source="meta-llama/Meta-Llama-3.1-8B-Instruct",
  593. concurrency=1,
  594. batch_size=64,
  595. )
  596. processor = build_processor(
  597. config,
  598. preprocess=lambda row: dict(
  599. messages=[{"role": "user", "content": row["prompt"]}],
  600. sampling_params=dict(temperature=0.3, max_tokens=20),
  601. ),
  602. postprocess=lambda row: dict(resp=row["generated_text"]),
  603. preprocess_map_kwargs={"num_cpus": 0.5},
  604. postprocess_map_kwargs={"num_cpus": 0.25},
  605. )
  606. ds = ray.data.range(300)
  607. ds = processor(ds)
  608. for row in ds.take_all():
  609. print(row)
  610. Using builder_kwargs to pass chat_template_kwargs:
  611. .. testcode::
  612. :skipif: True
  613. import ray
  614. from ray.data.llm import vLLMEngineProcessorConfig, build_processor
  615. config = vLLMEngineProcessorConfig(
  616. model_source="Qwen/Qwen3-0.6B",
  617. chat_template_stage={"enabled": True},
  618. concurrency=1,
  619. batch_size=64,
  620. )
  621. processor = build_processor(
  622. config,
  623. preprocess=lambda row: dict(
  624. messages=[
  625. {"role": "user", "content": row["prompt"]},
  626. ],
  627. sampling_params=dict(
  628. temperature=0.6,
  629. max_tokens=100,
  630. ),
  631. ),
  632. builder_kwargs=dict(
  633. chat_template_kwargs={"enable_thinking": True},
  634. ),
  635. )
  636. ds = ray.data.from_items([{"prompt": "What is 2+2?"}])
  637. ds = processor(ds)
  638. for row in ds.take_all():
  639. print(row)
  640. """
  641. from ray.llm._internal.batch.processor import ProcessorBuilder
  642. ProcessorBuilder.validate_builder_kwargs(builder_kwargs)
  643. build_kwargs = dict(
  644. preprocess=preprocess,
  645. postprocess=postprocess,
  646. preprocess_map_kwargs=preprocess_map_kwargs,
  647. postprocess_map_kwargs=postprocess_map_kwargs,
  648. )
  649. # Pass through any additional builder kwargs
  650. if builder_kwargs is not None:
  651. build_kwargs.update(builder_kwargs)
  652. return ProcessorBuilder.build(config, **build_kwargs)
  653. __all__ = [
  654. "ProcessorConfig",
  655. "Processor",
  656. "HttpRequestProcessorConfig",
  657. "vLLMEngineProcessorConfig",
  658. "SGLangEngineProcessorConfig",
  659. "ServeDeploymentProcessorConfig",
  660. "ChatTemplateStageConfig",
  661. "DetokenizeStageConfig",
  662. "PrepareMultimodalStageConfig",
  663. "TokenizerStageConfig",
  664. "HttpRequestStageConfig",
  665. "PrepareImageStageConfig",
  666. "build_llm_processor",
  667. "build_processor",
  668. ]