| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347 |
- # Copyright 2025 The HuggingFace Inc. team
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import time
- from dataclasses import dataclass, field
- from enum import IntEnum
- import torch
- from ...utils import is_psutil_available, is_torch_xpu_available
- from ...utils.logging import logging
- from ...utils.metrics import traced
- if is_psutil_available():
- import psutil
# This is a temporary token ID used as a placeholder for a token that is not yet generated
TMP_TOKEN_ID = -1

# We centralize the logger here to coordinate between logging and progress bar
logger = logging.getLogger("ContinuousBatchingLogger")

# Add a handler to the logger to print the logs to the console. This only happens once: `propagate` is set to False
# right after the handler is added, so the condition below is False if this setup runs again for the same logger.
if logger.propagate:
    handler = logging.StreamHandler()
    # Log line layout: timestamp - logger name - level name - message
    handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s"))
    logger.addHandler(handler)
    logger.propagate = False
def get_device_and_memory_breakdown() -> tuple[torch.device, int, int, int]:
    """Select the device to run on and report its memory usage.

    Backends are probed in this order: CUDA, XPU, MPS, and finally CPU as the fallback. For CUDA and XPU, the
    allocator cache is emptied and the device synchronized before the memory counters are read.

    Returns:
        A `(device, total_memory, reserved_memory, allocated_memory)` tuple. The memory values are the byte counts
        reported by the backend (or by psutil on CPU):
            - CUDA / XPU: device total memory, allocator-reserved memory and allocated memory.
            - MPS: recommended maximum working set, driver-allocated memory and currently allocated memory.
            - CPU with psutil: total virtual memory, then the process RSS for both reserved and allocated.
            - CPU without psutil: an error is logged and all three values are 0.
    """
    if torch.cuda.is_available():
        device = torch.device("cuda")
        # Release cached blocks and wait for pending work so that the counters below are up to date
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        total_memory = torch.cuda.get_device_properties(device).total_memory
        reserved_memory = torch.cuda.memory_reserved(device)
        allocated_memory = torch.cuda.memory_allocated(device)
    elif is_torch_xpu_available():
        device = torch.device("xpu")
        torch.xpu.empty_cache()
        torch.xpu.synchronize()
        total_memory = torch.xpu.get_device_properties(device).total_memory
        reserved_memory = torch.xpu.memory_reserved(device)
        allocated_memory = torch.xpu.memory_allocated(device)
    elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
        device = torch.device("mps")
        # MPS memory reporting (PyTorch 2.0+). MPS exposes no device-total query, so the recommended maximum working
        # set is used as the total. It is looked up through getattr, as older torch.mps modules may not declare it.
        total_memory = getattr(torch.mps, "recommended_max_memory")()
        # Memory held by the Metal driver for this process (includes the allocator's cached pools): the closest
        # equivalent of "reserved" memory on the other backends
        reserved_memory = torch.mps.driver_allocated_memory()
        # Memory currently occupied by tensors
        allocated_memory = torch.mps.current_allocated_memory()
    else:
        device = torch.device("cpu")
        if is_psutil_available():
            total_memory = psutil.virtual_memory().total
            allocated_memory = psutil.Process().memory_info().rss
            # There is no separate notion of reserved memory on CPU: reuse the resident set size of the process
            reserved_memory = allocated_memory
        else:
            logger.error(
                "Cannot get memory breakdown on CPU without psutil: returning 0 for all memory values. Please install "
                "psutil to get an actual memory breakdown."
            )
            total_memory = 0
            reserved_memory = 0
            allocated_memory = 0
    return device, total_memory, reserved_memory, allocated_memory
class RequestStatus(IntEnum):
    """Status of a generation request through its lifecycle."""

    # Default status of a newly created request
    PENDING = 0
    # Presumably: the prompt tokens are being processed -- not set anywhere in this file, confirm against the scheduler
    PREFILLING = 1
    # New tokens are being generated: the only status in which `RequestState.update_and_check_completion` accepts tokens
    DECODING = 2
    # Generation is over: set when an EOS token is produced or the limit of new tokens is reached
    FINISHED = 3
    # Presumably: the request ended with an error -- not set anywhere in this file, confirm against the callers
    FAILED = 4
@dataclass
class GenerationOutput:
    """Result record of a single generation request.

    Attributes:
        request_id (str): Identifier of the generation request.
        prompt_ids (list[int]): Token IDs of the prompt.
        generated_tokens (list[int]): Token IDs produced by the generation.
        logprobs (list[float]): Log probabilities of the generated tokens.
        error (str | None): Error message attached to the request. None means the request was successful.
        status (RequestStatus): Status of the request.
        created_time (float): Time at which the request was created. Defaults to a `time.perf_counter()` reading.
        lifespan (tuple[float, float]): Time at which the request stopped being pending, then time at which it
            finished. Both default to -1.
        timestamps (list[float] | None): Timestamps of the generated tokens. Defaults to None.
    """

    request_id: str
    prompt_ids: list[int] = field(default_factory=list)
    generated_tokens: list[int] = field(default_factory=list)
    logprobs: list[float] = field(default_factory=list)
    error: str | None = None
    status: RequestStatus = RequestStatus.PENDING
    created_time: float = field(default_factory=time.perf_counter)
    lifespan: tuple[float, float] = (-1, -1)  # (time request was no longer pending, time request finished)
    timestamps: list[float] | None = None  # Timestamps of the generated tokens

    def is_finished(self) -> bool:
        """Tell whether the request has reached the FINISHED status."""
        return RequestStatus.FINISHED == self.status
@dataclass
class RequestState:
    """Tracks the state of a generation request through its lifecycle.

    Attributes:
        request_id (str): The ID of the generation request.
        initial_tokens (list[int]): The initial prompt tokens.
        record_timestamps (bool): Whether to record a timestamp for each generated token.
        num_children (int): The number of children requests.
        tokens_to_process (list[int]): The token IDs currently being processed.
        remaining_prefill_tokens (list[int]): The initial token IDs remaining to be processed. Always reset to a copy
            of `initial_tokens` in `__post_init__`.
        generated_tokens (list[int]): The generated tokens.
        logprobs (list[float]): The log probabilities of the generated tokens.
        allocated_blocks (int): The number of blocks allocated to the request.
        position_offset (int): The current position in the sequence for position_ids.
        _status (RequestStatus): The status of the request, exposed through the `status` property: can be one of
            PENDING, PREFILLING, DECODING, FINISHED, FAILED.
        max_new_tokens (int | None): The maximum number of new tokens to generate. None means no limit.
        eos_token_id (None | int | list[int]): The ID(s) of the end-of-sequence tokens. Only used in post-init.
        _eos_token_ids (set[int]): The IDs of the end-of-sequence tokens, formatted as a set.
        streaming (bool): Whether to stream tokens as they're generated.
        created_time (float): The time the request was created.
        error (str | None): Any error message associated with the request. When None, has had no error yet.
        lifespan (tuple[float, float]): The time the request was no longer pending and the time the request finished.
        _timestamps (list[float]): The timestamps of the generated tokens, exposed through the `timestamps` property.
        _true_initial_tokens (int): The true number of initial tokens, useful when soft resetting requests. 0 means
            the request was never soft reset.
        _new_tokens_limit (int): The limit of new tokens as an int, derived from `max_new_tokens` in post-init.
    """

    # Required fields
    request_id: str
    initial_tokens: list[int]  # Initial prompt tokens # TODO: rename this as prefill tokens

    # Optional fields
    record_timestamps: bool = False  # Whether to record timestamps for the generated tokens
    num_children: int = 0  # Number of children requests

    # Internal fields
    tokens_to_process: list[int] = field(default_factory=list)  # Tokens IDs currently being processed
    remaining_prefill_tokens: list[int] = field(
        default_factory=list
    )  # Initial tokens left to process (initialized in __post_init__)
    generated_tokens: list[int] = field(default_factory=list)  # Generated tokens
    logprobs: list[float] = field(default_factory=list)  # Log probabilities of the generated tokens
    allocated_blocks: int = 0  # Number of blocks allocated to the request
    position_offset: int = 0  # Current position in the sequence for position_ids
    _status: RequestStatus = RequestStatus.PENDING  # Status of the request, hidden behind a property
    max_new_tokens: int | None = 20  # Maximum number of new tokens to generate. None means no limit. Default to 20.
    eos_token_id: int | list[int] | None = None  # ID(s) of the end-of-sequence tokens. Only used in post-init.
    _eos_token_ids: set[int] = field(default_factory=set)  # IDs of the end-of-sequence tokens, formatted as a set
    streaming: bool = False  # Whether to stream tokens as they're generated
    created_time: float = field(default_factory=time.perf_counter)  # Time the request was created
    error: str | None = None  # Error message if the request failed
    lifespan: tuple[float, float] = (-1, -1)  # (time request was no longer pending, time request finished)
    _timestamps: list[float] = field(default_factory=list)  # Timestamps of the generated tokens
    _true_initial_tokens: int = 0  # The true number of initial tokens, useful when soft resetting requests
    # TODO: remove the attribute above to _num_initial_tokens once initial_tokens is renamed
    _new_tokens_limit: int = 2147483647  # An int to check the max number of new tokens w/out always comparing w/ None

    def __post_init__(self):
        """Derive the internal fields that depend on the constructor arguments."""
        # If no max length is set, we set an absurdly high value which will never be reached
        self._new_tokens_limit = 2147483647 if self.max_new_tokens is None else self.max_new_tokens
        # Keep a copy of the initial tokens to process
        self.remaining_prefill_tokens = self.initial_tokens[:]
        # Format the EOS token ID(s) as a set of ints. If there is no EOS token ID, the set is left untouched
        if self.eos_token_id is None:
            return
        # A single EOS token ID is handled like a list of one ID
        candidates = [self.eos_token_id] if isinstance(self.eos_token_id, int) else self.eos_token_id
        # Only valid IDs, ie. non-negative ones, are added to the set
        self._eos_token_ids.update(token_id for token_id in candidates if token_id >= 0)

    @property
    def status(self) -> RequestStatus:
        """The current status of the request."""
        return self._status

    @status.setter
    def status(self, value: RequestStatus):
        """Set the status of the request and keep `lifespan` in sync with the transition."""
        # Leaving the PENDING status marks the start of the lifespan
        if self._status == RequestStatus.PENDING:
            self.lifespan = (time.perf_counter(), -1)
        # Reaching the FINISHED status marks the end of the lifespan. This is checked independently from the branch
        # above, so that a request going straight from PENDING to FINISHED still gets an end time and an end log.
        if value == RequestStatus.FINISHED:
            self.lifespan = (self.lifespan[0], time.perf_counter())
            self.log_end_of_request()
        self._status = value

    @property
    def timestamps(self) -> list[float] | None:
        """The timestamps of the generated tokens, or None if timestamps are not recorded."""
        return self._timestamps if self.record_timestamps else None

    def log_end_of_request(self):
        """Log the prompt length, the number of generated tokens and the lifespan relative to the creation time."""
        prefill_len = len(self.initial_tokens)
        decode_len = self.generated_len()
        start_time = self.lifespan[0] - self.created_time
        end_time = self.lifespan[1] - self.created_time
        logger.info(
            f"Request {self.request_id} finished: {prefill_len = } {decode_len = } {start_time = } {end_time = }"
        )

    def current_len(self) -> int:
        """Get the current length of the sequence (prompt + generated tokens)."""
        return self.position_offset

    def generated_len(self) -> int:
        """Get the number of tokens generated so far."""
        return len(self.generated_tokens)

    # TODO: this logic seems one token off, check it out
    @traced
    def update_and_check_completion(self, token_id: int, logprob: float | None) -> bool:
        """Update the request with a newly generated token (and optional log probability of the token) and check for
        completion. Returns True if the request is now complete, False otherwise."""
        # Only update if we're in decoding state # TODO: seems useless (always true) -- remove this
        if self.status != RequestStatus.DECODING:
            return False

        # If we're recording timestamps, add timestamp to the list
        if self.record_timestamps:
            self._timestamps.append(time.perf_counter())

        # Stop if we reached an EOS token
        is_eos = token_id in self._eos_token_ids
        current_len = self.generated_len()

        # Replace the temporary token if we're not finishing due to max length
        # (EOS tokens should still be added to the output)
        if is_eos or (current_len < self._new_tokens_limit):
            self.generated_tokens.append(token_id)
            self.tokens_to_process = [token_id]  # this works for 2 levels of pipelines, but not sure for more
            current_len += 1
            if logprob is not None:
                self.logprobs.append(logprob)
        else:
            logger.warning(f"Request {self.request_id} generated a useless token: {token_id}")

        if is_eos or current_len >= self._new_tokens_limit:
            self.status = RequestStatus.FINISHED
            return True
        return False  # We still need to process more tokens

    def __repr__(self):
        msg = [
            f"request_id={self.request_id}",
            f"status={self._status}",
            f"out_tokens={self.generated_len()}",
            f"query_length={len(self.tokens_to_process)}",
            f"remaining_tokens={len(self.remaining_prefill_tokens)}",
            f"kv_length={self.position_offset}",
            f"full_prompt_length={len(self.initial_tokens)}",
            f"allocated_blocks={self.allocated_blocks}",
            f"generated_tokens={self.generated_tokens}",
        ]
        return "RequestState(\n\t" + ",\n\t".join(msg) + "\n)"

    def to_generation_output(self):
        """Convert the request state to a GenerationOutput object.

        If the request was soft reset (`_true_initial_tokens` is non-zero), the tokens that were moved to the prompt
        are first moved back in front of the generated tokens. This updates `initial_tokens` and `generated_tokens`
        on the request itself.
        """
        if self._true_initial_tokens:
            self.generated_tokens = self.initial_tokens[self._true_initial_tokens :] + self.generated_tokens
            self.initial_tokens = self.initial_tokens[: self._true_initial_tokens]
        return GenerationOutput(
            request_id=self.request_id,
            prompt_ids=self.initial_tokens,
            generated_tokens=self.generated_tokens,
            logprobs=self.logprobs,
            error=self.error,
            status=self.status,
            created_time=self.created_time,
            lifespan=self.lifespan,
            timestamps=self.timestamps,
        )

    def fork(self, new_request_id: str) -> "RequestState":
        """Fork the request into a new request with the same state except for request_id, created_time and lifespan.
        The recorded timestamps are not carried over: the new request starts with an empty list."""
        t = time.perf_counter()
        new_request = RequestState(
            request_id=new_request_id,
            initial_tokens=self.initial_tokens,
            num_children=self.num_children,
            tokens_to_process=self.tokens_to_process[:],
            generated_tokens=self.generated_tokens[:],
            logprobs=self.logprobs[:],
            allocated_blocks=self.allocated_blocks,
            position_offset=self.position_offset,
            _status=self.status,
            max_new_tokens=self.max_new_tokens,
            eos_token_id=self.eos_token_id,
            streaming=self.streaming,
            created_time=t,
            lifespan=(t, -1),
            _timestamps=[],
            error=self.error,
            record_timestamps=self.record_timestamps,
            # Keep the prompt / generated boundary of a soft-reset request, so that the fork splits its tokens the
            # same way as its parent in to_generation_output
            _true_initial_tokens=self._true_initial_tokens,
        )
        # Modified by __post_init__
        new_request.remaining_prefill_tokens = self.remaining_prefill_tokens[:]
        return new_request

    def create_equivalent_initial_request(self) -> "RequestState":
        """Creates an equivalent new request by removing the generated tokens and adding them to the initial prompt. The
        created request has THE SAME request_id. Notably, we can retrieve the original request from the created one with
        the _true_initial_tokens attribute. The logprobs of the generated tokens are kept in the new request."""
        # The tokens generated so far count against the budget of new tokens of the created request
        max_new_tokens = None if self.max_new_tokens is None else (self.max_new_tokens - len(self.generated_tokens))
        new_state = RequestState(
            request_id=self.request_id,
            initial_tokens=self.initial_tokens + self.generated_tokens,
            logprobs=self.logprobs[:],
            num_children=self.num_children,
            record_timestamps=self.record_timestamps,
            max_new_tokens=max_new_tokens,
            eos_token_id=self.eos_token_id,
            streaming=self.streaming,
        )
        # If the request has been soft reset once already, this stays the same
        if self._true_initial_tokens:
            new_state._true_initial_tokens = self._true_initial_tokens
        # Otherwise, we set the true initial tokens to the number of initial tokens
        else:
            new_state._true_initial_tokens = len(self.initial_tokens)
        return new_state
class FutureRequestState:
    """Bundles a request state with the information relevant to update it.

    Attributes:
        state (RequestState): The request state being tracked.
        has_new_token (bool): Stored as given to the constructor.
        complete_blocks (int): Stored as given to the constructor.
    """

    # Declaring the attributes as slots makes instantiating this class faster
    __slots__ = ("state", "has_new_token", "complete_blocks")

    def __init__(self, state: RequestState, has_new_token: bool, complete_blocks: int) -> None:
        # Plain storage of the three arguments: no validation and no copy
        self.complete_blocks = complete_blocks
        self.has_new_token = has_new_token
        self.state = state
|