requests.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347
  1. # Copyright 2025 The HuggingFace Inc. team
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import time
  15. from dataclasses import dataclass, field
  16. from enum import IntEnum
  17. import torch
  18. from ...utils import is_psutil_available, is_torch_xpu_available
  19. from ...utils.logging import logging
  20. from ...utils.metrics import traced
  21. if is_psutil_available():
  22. import psutil
  23. # This is a temporary token ID used to represent a token that is not yet generated
  24. TMP_TOKEN_ID = -1
  25. # We centralize the logger here to coordinate between logging and progress bar
  26. logger = logging.getLogger("ContinuousBatchingLogger")
  27. # Add a handler to the logger to print the logs to the console. Only happens once thanks to setting propagate to False.
  28. if logger.propagate:
  29. handler = logging.StreamHandler()
  30. handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s"))
  31. logger.addHandler(handler)
  32. logger.propagate = False
  33. def get_device_and_memory_breakdown() -> tuple[torch.device, int, int, int]:
  34. if torch.cuda.is_available():
  35. device = torch.device("cuda")
  36. torch.cuda.empty_cache()
  37. torch.cuda.synchronize()
  38. total_memory = torch.cuda.get_device_properties(device).total_memory
  39. reserved_memory = torch.cuda.memory_reserved(device)
  40. allocated_memory = torch.cuda.memory_allocated(device)
  41. elif is_torch_xpu_available():
  42. device = torch.device("xpu")
  43. torch.xpu.empty_cache()
  44. torch.xpu.synchronize()
  45. total_memory = torch.xpu.get_device_properties(device).total_memory
  46. reserved_memory = torch.xpu.memory_reserved(device)
  47. allocated_memory = torch.xpu.memory_allocated(device)
  48. elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
  49. device = torch.device("mps")
  50. # MPS memory reporting (PyTorch 2.0+)
  51. total_memory = torch.mps.driver_allocated_memory()
  52. allocated_memory = total_memory - getattr(torch.mps, "recommended_max_memory")()
  53. reserved_memory = 0 # MPS does not track reserved separately
  54. else:
  55. device = torch.device("cpu")
  56. if is_psutil_available():
  57. total_memory = psutil.virtual_memory().total
  58. allocated_memory = psutil.Process().memory_info().rss
  59. reserved_memory = allocated_memory
  60. else:
  61. logger.error(
  62. "Cannot get memory breakdown on CPU without psutil: returning 0 for all memory values. Please install "
  63. "psutil to get an actual memory breakdown."
  64. )
  65. total_memory = 0
  66. reserved_memory = 0
  67. allocated_memory = 0
  68. return device, total_memory, reserved_memory, allocated_memory
  69. class RequestStatus(IntEnum):
  70. """Status of a generation request through its lifecycle."""
  71. PENDING = 0
  72. PREFILLING = 1
  73. DECODING = 2
  74. FINISHED = 3
  75. FAILED = 4
  76. @dataclass
  77. class GenerationOutput:
  78. """Tracks the output of a generation request.
  79. Attributes:
  80. request_id (str): The ID of the generation request.
  81. prompt_ids (list[int]): The IDs of the prompt tokens.
  82. generated_tokens (list[int]): The generated tokens.
  83. logprobs (list[float]): The log probabilities of the generated tokens.
  84. error (Optional[str]): Any error message associated with the request. When None, the request was successful.
  85. status (RequestStatus): The status of the request.
  86. created_time (float): The time the request was created.
  87. lifespan (tuple[float, float]): The time the request was no longer pending and the time the request finished.
  88. """
  89. request_id: str
  90. prompt_ids: list[int] = field(default_factory=list)
  91. generated_tokens: list[int] = field(default_factory=list)
  92. logprobs: list[float] = field(default_factory=list)
  93. error: str | None = None
  94. status: RequestStatus = RequestStatus.PENDING
  95. created_time: float = field(default_factory=time.perf_counter)
  96. lifespan: tuple[float, float] = (-1, -1) # (time request was no longer pending, time request finished)
  97. timestamps: list[float] | None = None # Timestamps of the generated tokens
  98. def is_finished(self) -> bool:
  99. return self.status == RequestStatus.FINISHED
  100. @dataclass
  101. class RequestState:
  102. """Tracks the state of a generation request through its lifecycle.
  103. Attributes:
  104. request_id (str): The ID of the generation request.
  105. initial_tokens (list[int]): The initial prompt tokens.
  106. num_children (int): The number of children requests
  107. full_prompt_ids (list[int] | None): The tokens IDs of the full prompt.
  108. prompt_ids (list[int] | None): The tokens IDs currently being processed.
  109. remaining_prompt_ids (list[int]): The initial tokens IDs remaining to be processed.
  110. static_outputs (list[int]): The generated tokens.
  111. allocated_blocks (int): The number of blocks allocated to the request.
  112. position_offset (int): The current position in the sequence for position_ids.
  113. status (RequestStatus): The status of the request: can be one of PENDING, PREFILLING, PREFILLING_SPLIT,
  114. SPLIT_PENDING_REMAINDER, DECODING, FINISHED, FAILED
  115. max_new_tokens (int | None): The maximum number of new tokens to generate.
  116. eos_token_id (None | int | list[int]): The ID(s) of the end-of-sequence tokens. Only used in post-init.
  117. _eos_token_ids (set[int]): The IDs of the end-of-sequence tokens, formatted as a set.
  118. streaming (bool): Whether to stream tokens as they're generated
  119. created_time (float): The time the request was created.
  120. error (Optional[str]): Any error message associated with the request. When None, has had no error yet.
  121. """
  122. # Required fields
  123. request_id: str
  124. initial_tokens: list[int] # Initial prompt tokens # TODO: rename this as prefill tokens
  125. # Optional fields
  126. record_timestamps: bool = False # Whether to record timestamps for the generated tokens
  127. num_children: int = 0 # Number of children requests
  128. # Internal fields
  129. tokens_to_process: list[int] = field(default_factory=list) # Tokens IDs currently being processed
  130. remaining_prefill_tokens: list[int] = field(
  131. default_factory=list
  132. ) # Initial tokens left to process (initialized in __post_init__)
  133. generated_tokens: list[int] = field(default_factory=list) # Generated tokens
  134. logprobs: list[float] = field(default_factory=list) # Log probabilities of the generated tokens
  135. allocated_blocks: int = 0 # Number of blocks allocated to the request
  136. position_offset: int = 0 # Current position in the sequence for position_ids
  137. _status: RequestStatus = RequestStatus.PENDING # Status of the request, hidden behind a property
  138. max_new_tokens: int | None = 20 # Maximum number of new tokens to generate. None means no limit. Default to 20.
  139. eos_token_id: int | list[int] | None = None # ID(s) of the end-of-sequence tokens. Only used in post-init.
  140. _eos_token_ids: set[int] = field(default_factory=set) # IDs of the end-of-sequence tokens, formatted as a set
  141. streaming: bool = False # Whether to stream tokens as they're generated
  142. created_time: float = field(default_factory=time.perf_counter) # Time the request was created
  143. error: str | None = None # Error message if the request failed
  144. lifespan: tuple[float, float] = (-1, -1) # (time request was no longer pending, time request finished)
  145. _timestamps: list[float] = field(default_factory=list) # Timestamps of the generated tokens
  146. _true_initial_tokens: int = 0 # The true number of initial tokens, useful when soft resetting requests
  147. # TODO: remove the attribute above to _num_initial_tokens once initial_tokens is renamed
  148. _new_tokens_limit: int = 2147483647 # An int to check the max number of new tokens w/out always comparing w/ None
  149. def __post_init__(self):
  150. # If no max length is set, we set an absurdly high value which will never be reached
  151. self._new_tokens_limit = 2147483647 if self.max_new_tokens is None else self.max_new_tokens
  152. # Keep a copy of the initial tokens to process
  153. self.remaining_prefill_tokens = self.initial_tokens[:]
  154. # Format the EOS token ID(s) as a set of ints. If there is no EOS token ID, it's an empty set
  155. if self.eos_token_id is None:
  156. pass
  157. # If there is a single EOS token ID, add it to the set only if the ID is valid, ie. non-negative
  158. elif isinstance(self.eos_token_id, int):
  159. if self.eos_token_id >= 0:
  160. self._eos_token_ids.add(self.eos_token_id)
  161. # If there are multiple EOS token IDs, add them to the set only if they are valid, ie. non-negative
  162. else:
  163. for token_id in self.eos_token_id:
  164. if token_id >= 0:
  165. self._eos_token_ids.add(token_id)
  166. @property
  167. def status(self) -> RequestStatus:
  168. return self._status
  169. @status.setter
  170. def status(self, value: RequestStatus):
  171. if self._status == RequestStatus.PENDING:
  172. self.lifespan = (time.perf_counter(), -1)
  173. elif value == RequestStatus.FINISHED:
  174. self.lifespan = (self.lifespan[0], time.perf_counter())
  175. self.log_end_of_request()
  176. self._status = value
  177. @property
  178. def timestamps(self) -> list[float] | None:
  179. return self._timestamps if self.record_timestamps else None
  180. def log_end_of_request(self):
  181. prefill_len = len(self.initial_tokens)
  182. decode_len = self.generated_len()
  183. start_time = self.lifespan[0] - self.created_time
  184. end_time = self.lifespan[1] - self.created_time
  185. logger.info(
  186. f"Request {self.request_id} finished: {prefill_len = } {decode_len = } {start_time = } {end_time = }"
  187. )
  188. def current_len(self) -> int:
  189. """Get the current length of the sequence (prompt + generated tokens)."""
  190. return self.position_offset
  191. def generated_len(self) -> int:
  192. """Get the number of tokens generated so far."""
  193. return len(self.generated_tokens)
  194. # TODO: this logic seems one token off, check it out
  195. @traced
  196. def update_and_check_completion(self, token_id: int, logprob: float | None) -> bool:
  197. """Update the request with a newly generated token (and optional log probability of the token) and check for
  198. completion. Returns True if the request is now complete, False otherwise."""
  199. # Only update if we're in decoding state # TODO: seems useless (always true) -- remove this
  200. if self.status != RequestStatus.DECODING:
  201. return False
  202. # If we're recording timestamps, add timestamp to the list
  203. if self.record_timestamps:
  204. self._timestamps.append(time.perf_counter())
  205. # Stop if we reached an EOS token
  206. is_eos = token_id in self._eos_token_ids
  207. current_len = self.generated_len()
  208. # Replace the temporary token if we're not finishing due to max length
  209. # (EOS tokens should still be added to the output)
  210. if is_eos or (current_len < self._new_tokens_limit):
  211. self.generated_tokens.append(token_id)
  212. self.tokens_to_process = [token_id] # this works for 2 levels of pipelines, but not sure for more
  213. current_len += 1
  214. if logprob is not None:
  215. self.logprobs.append(logprob)
  216. else:
  217. logger.warning(f"Request {self.request_id} generated a useless token: {token_id}")
  218. if is_eos or current_len >= self._new_tokens_limit:
  219. self.status = RequestStatus.FINISHED
  220. return True
  221. return False # We still need to process more tokens
  222. def __repr__(self):
  223. msg = [
  224. f"request_id={self.request_id}",
  225. f"status={self._status}",
  226. f"out_tokens={self.generated_len()}",
  227. f"query_length={len(self.tokens_to_process)}",
  228. f"remaining_tokens={len(self.remaining_prefill_tokens)}",
  229. f"kv_length={self.position_offset}",
  230. f"full_prompt_length={len(self.initial_tokens)}",
  231. f"allocated_blocks={self.allocated_blocks}",
  232. f"generated_tokens={self.generated_tokens}",
  233. ]
  234. return "RequestState(\n\t" + ",\n\t".join(msg) + "\n)"
  235. def to_generation_output(self):
  236. """Convert the request state to a GenerationOutput object."""
  237. if self._true_initial_tokens:
  238. self.generated_tokens = self.initial_tokens[self._true_initial_tokens :] + self.generated_tokens
  239. self.initial_tokens = self.initial_tokens[: self._true_initial_tokens]
  240. return GenerationOutput(
  241. request_id=self.request_id,
  242. prompt_ids=self.initial_tokens,
  243. generated_tokens=self.generated_tokens,
  244. logprobs=self.logprobs,
  245. error=self.error,
  246. status=self.status,
  247. created_time=self.created_time,
  248. lifespan=self.lifespan,
  249. timestamps=self.timestamps,
  250. )
  251. def fork(self, new_request_id: str) -> "RequestState":
  252. """Fork the request into a new request with the same state except for request_id, created_time and lifespan."""
  253. t = time.perf_counter()
  254. new_request = RequestState(
  255. request_id=new_request_id,
  256. initial_tokens=self.initial_tokens,
  257. num_children=self.num_children,
  258. tokens_to_process=self.tokens_to_process[:],
  259. generated_tokens=self.generated_tokens[:],
  260. logprobs=self.logprobs[:],
  261. allocated_blocks=self.allocated_blocks,
  262. position_offset=self.position_offset,
  263. _status=self.status,
  264. max_new_tokens=self.max_new_tokens,
  265. eos_token_id=self.eos_token_id,
  266. streaming=self.streaming,
  267. created_time=t,
  268. lifespan=(t, -1),
  269. _timestamps=[],
  270. error=self.error,
  271. record_timestamps=self.record_timestamps,
  272. )
  273. # Modified by __post_init__
  274. new_request.remaining_prefill_tokens = self.remaining_prefill_tokens[:]
  275. return new_request
  276. def create_equivalent_initial_request(self) -> "RequestState":
  277. """Creates an equivalent new request by removing the generated tokens and adding them to the initial prompt. The
  278. created request has THE SAME request_id. Notably, we can retrieve the original request from the created one with
  279. the _true_initial_tokens attribute. The logprobs of the generated tokens are kept in the new request."""
  280. max_new_tokens = None if self.max_new_tokens is None else (self.max_new_tokens - len(self.generated_tokens))
  281. new_state = RequestState(
  282. request_id=self.request_id,
  283. initial_tokens=self.initial_tokens + self.generated_tokens,
  284. logprobs=self.logprobs[:],
  285. num_children=self.num_children,
  286. record_timestamps=self.record_timestamps,
  287. max_new_tokens=max_new_tokens,
  288. eos_token_id=self.eos_token_id,
  289. streaming=self.streaming,
  290. )
  291. # If the request has been soft reset once already, this stays the same
  292. if self._true_initial_tokens:
  293. new_state._true_initial_tokens = self._true_initial_tokens
  294. # Otherwise, we set the true initial tokens to the number of initial tokens
  295. else:
  296. new_state._true_initial_tokens = len(self.initial_tokens)
  297. return new_state
  298. class FutureRequestState:
  299. """Tracks the current state of a request and the relevant information to update it."""
  300. # This makes instantiating this class faster
  301. __slots__ = ("state", "has_new_token", "complete_blocks")
  302. def __init__(self, state: RequestState, has_new_token: bool, complete_blocks: int) -> None:
  303. self.state = state
  304. self.has_new_token = has_new_token
  305. self.complete_blocks = complete_blocks