model_manager.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457
  1. # Copyright 2026 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. Model loading, caching, and lifecycle management.
  16. """
  17. import asyncio
  18. import gc
  19. import json
  20. import threading
  21. from collections.abc import Callable
  22. from functools import lru_cache
  23. from typing import TYPE_CHECKING
  24. from huggingface_hub import scan_cache_dir
  25. from tqdm import tqdm
  26. import transformers
  27. from transformers import BitsAndBytesConfig, PreTrainedTokenizerBase
  28. from ...utils import logging
  29. from .utils import Modality, make_progress_tqdm_class, reset_torch_cache
  30. if TYPE_CHECKING:
  31. from transformers import PreTrainedModel, PreTrainedTokenizerFast, ProcessorMixin
# Module-level logger, created with the `logging` helper imported from the parent package's `utils`.
logger = logging.get_logger(__name__)
  33. class TimedModel:
  34. """Wraps a model + processor and auto-unloads them after a period of inactivity.
  35. Args:
  36. model: The loaded model.
  37. timeout_seconds: Seconds of inactivity before auto-unload. Use -1 to disable.
  38. processor: The associated processor or tokenizer.
  39. on_unload: Optional callback invoked after the model is unloaded from memory.
  40. """
  41. def __init__(
  42. self,
  43. model: "PreTrainedModel",
  44. timeout_seconds: int,
  45. processor: "ProcessorMixin | PreTrainedTokenizerFast | None" = None,
  46. on_unload: "Callable | None" = None,
  47. ):
  48. self.model = model
  49. self._name_or_path = str(model.name_or_path)
  50. self.processor = processor
  51. self.timeout_seconds = timeout_seconds
  52. self._on_unload = on_unload
  53. self._timer = threading.Timer(self.timeout_seconds, self._timeout_reached)
  54. self._timer.start()
  55. def reset_timer(self) -> None:
  56. """Reset the inactivity timer (called on each request)."""
  57. self._timer.cancel()
  58. self._timer = threading.Timer(self.timeout_seconds, self._timeout_reached)
  59. self._timer.start()
  60. def delete_model(self) -> None:
  61. """Delete the model and processor, free GPU memory."""
  62. if hasattr(self, "model") and self.model is not None:
  63. del self.model
  64. del self.processor
  65. self.model = None
  66. self.processor = None
  67. gc.collect()
  68. reset_torch_cache()
  69. self._timer.cancel()
  70. if self._on_unload is not None:
  71. self._on_unload()
  72. def _timeout_reached(self) -> None:
  73. if self.timeout_seconds > 0:
  74. self.delete_model()
  75. logger.info(f"{self._name_or_path} was removed from memory after {self.timeout_seconds}s of inactivity")
  76. class ModelManager:
  77. """Loads, caches, and manages the lifecycle of models.
  78. Handlers receive a reference to this and call `load_model_and_processor()`
  79. to get a model ready for inference.
  80. Args:
  81. device: Device to place models on (e.g. "auto", "cuda", "cpu").
  82. dtype: Torch dtype override. "auto" derives from model weights.
  83. trust_remote_code: Whether to trust remote code when loading models.
  84. attn_implementation: Attention implementation override (e.g. "flash_attention_2").
  85. quantization: Quantization method ("bnb-4bit" or "bnb-8bit").
  86. model_timeout: Seconds before an idle model is unloaded. -1 disables.
  87. force_model: If set, preload this model at init time.
  88. """
  89. def __init__(
  90. self,
  91. device: str = "auto",
  92. dtype: str | None = "auto",
  93. trust_remote_code: bool = False,
  94. attn_implementation: str | None = None,
  95. quantization: str | None = None,
  96. model_timeout: int = 300,
  97. force_model: str | None = None,
  98. ):
  99. self.loaded_models: dict[str, TimedModel] = {}
  100. # Thread-safety for concurrent load_model_and_processor calls
  101. self._model_locks: dict[str, threading.Lock] = {}
  102. self._model_locks_guard = threading.Lock()
  103. # Tracks in-flight loads for fan-out to multiple SSE subscribers (used by load_model_streaming)
  104. self._loading_subscribers: dict[str, list[asyncio.Queue[str | None]]] = {}
  105. self._loading_tasks: dict[str, asyncio.Task] = {}
  106. # Convert numeric device strings (e.g. "0") to int so device_map works correctly
  107. self.device = int(device) if device.isdigit() else device
  108. self.dtype = self._resolve_dtype(dtype)
  109. self.trust_remote_code = trust_remote_code
  110. self.attn_implementation = attn_implementation
  111. self.quantization = quantization
  112. self.model_timeout = model_timeout
  113. self.force_model = force_model
  114. self._validate_args()
  115. # Preloaded models should never be auto-unloaded
  116. if force_model is not None:
  117. self.model_timeout = -1
  118. # Preload the forced model after all state is initialized
  119. if force_model is not None:
  120. self.load_model_and_processor(self.process_model_name(force_model))
  121. @staticmethod
  122. def _resolve_dtype(dtype: str | None):
  123. import torch
  124. if dtype in ("auto", None):
  125. return dtype
  126. resolved = getattr(torch, dtype, None)
  127. if not isinstance(resolved, torch.dtype):
  128. raise ValueError(
  129. f"Unsupported dtype: '{dtype}'. Must be 'auto' or a valid torch dtype (e.g. 'float16', 'bfloat16')."
  130. )
  131. return resolved
  132. def _validate_args(self):
  133. if self.quantization is not None and self.quantization not in ("bnb-4bit", "bnb-8bit"):
  134. raise ValueError(
  135. f"Unsupported quantization method: '{self.quantization}'. Must be 'bnb-4bit' or 'bnb-8bit'."
  136. )
  137. VALID_ATTN_IMPLEMENTATIONS = {"eager", "sdpa", "flash_attention_2", "flash_attention_3", "flex_attention"}
  138. is_kernels_community = self.attn_implementation is not None and self.attn_implementation.startswith(
  139. "kernels-community/"
  140. )
  141. if (
  142. self.attn_implementation is not None
  143. and not is_kernels_community
  144. and self.attn_implementation not in VALID_ATTN_IMPLEMENTATIONS
  145. ):
  146. raise ValueError(
  147. f"Unsupported attention implementation: '{self.attn_implementation}'. "
  148. f"Must be one of {VALID_ATTN_IMPLEMENTATIONS} or a kernels-community kernel (e.g. 'kernels-community/flash-attn2')."
  149. )
  150. @staticmethod
  151. def process_model_name(model_id: str) -> str:
  152. """Canonicalize to `'model_id@revision'` format. Defaults to `@main`."""
  153. if "@" in model_id:
  154. return model_id
  155. return f"{model_id}@main"
  156. def get_quantization_config(self) -> BitsAndBytesConfig | None:
  157. """Return a BitsAndBytesConfig based on the `quantization` setting, or None."""
  158. if self.quantization == "bnb-4bit":
  159. return BitsAndBytesConfig(
  160. load_in_4bit=True,
  161. bnb_4bit_quant_type="nf4",
  162. bnb_4bit_use_double_quant=True,
  163. )
  164. elif self.quantization == "bnb-8bit":
  165. return BitsAndBytesConfig(load_in_8bit=True)
  166. return None
  167. def _load_processor(self, model_id_and_revision: str) -> "ProcessorMixin | PreTrainedTokenizerFast":
  168. """Load a processor for the given model.
  169. Args:
  170. model_id_and_revision: Model ID in ``'model_id@revision'`` format.
  171. """
  172. from transformers import AutoProcessor
  173. model_id, revision = model_id_and_revision.split("@", 1)
  174. return AutoProcessor.from_pretrained(model_id, revision=revision, trust_remote_code=self.trust_remote_code)
  175. def _load_model(
  176. self, model_id_and_revision: str, tqdm_class: type | None = None, progress_callback: Callable | None = None
  177. ) -> "PreTrainedModel":
  178. """Load a model.
  179. Args:
  180. model_id_and_revision (`str`): Model ID in ``'model_id@revision'`` format.
  181. tqdm_class (*optional*): tqdm subclass for progress bars during ``from_pretrained``.
  182. progress_callback (`Callable`, *optional*): Called with progress dicts during loading.
  183. Returns:
  184. `PreTrainedModel`: The loaded model.
  185. """
  186. from transformers import AutoConfig
  187. model_id, revision = model_id_and_revision.split("@", 1)
  188. model_kwargs = {
  189. "revision": revision,
  190. "attn_implementation": self.attn_implementation,
  191. "dtype": self.dtype,
  192. "device_map": self.device,
  193. "trust_remote_code": self.trust_remote_code,
  194. "quantization_config": self.get_quantization_config(),
  195. "tqdm_class": tqdm_class,
  196. }
  197. if progress_callback is not None:
  198. progress_callback({"status": "loading", "model": model_id_and_revision, "stage": "config"})
  199. config = AutoConfig.from_pretrained(model_id, **model_kwargs)
  200. architecture = getattr(transformers, config.architectures[0])
  201. return architecture.from_pretrained(model_id, **model_kwargs)
  202. def load_model_and_processor(
  203. self,
  204. model_id_and_revision: str,
  205. progress_callback: Callable | None = None,
  206. tqdm_class: type | None = None,
  207. ) -> "tuple[PreTrainedModel, ProcessorMixin | PreTrainedTokenizerFast]":
  208. """Load a model (or return it from cache), resetting its inactivity timer.
  209. Args:
  210. model_id_and_revision: Model ID in ``'model_id@revision'`` format.
  211. progress_callback: If provided, called with dicts like
  212. ``{"status": "loading", "model": ..., "stage": ...}`` during loading.
  213. tqdm_class: Optional tqdm subclass for progress bars during ``from_pretrained``.
  214. """
  215. # Per-model lock prevents duplicate loads when concurrent requests arrive
  216. with self._model_locks_guard:
  217. lock = self._model_locks.setdefault(model_id_and_revision, threading.Lock())
  218. with lock:
  219. if model_id_and_revision not in self.loaded_models:
  220. logger.warning(f"Loading {model_id_and_revision}")
  221. if progress_callback is not None:
  222. progress_callback({"status": "loading", "model": model_id_and_revision, "stage": "processor"})
  223. processor = self._load_processor(model_id_and_revision)
  224. model = self._load_model(
  225. model_id_and_revision, tqdm_class=tqdm_class, progress_callback=progress_callback
  226. )
  227. self.loaded_models[model_id_and_revision] = TimedModel(
  228. model,
  229. timeout_seconds=self.model_timeout,
  230. processor=processor,
  231. on_unload=lambda key=model_id_and_revision: self.loaded_models.pop(key, None),
  232. )
  233. if progress_callback is not None:
  234. progress_callback({"status": "ready", "model": model_id_and_revision, "cached": False})
  235. else:
  236. self.loaded_models[model_id_and_revision].reset_timer()
  237. model = self.loaded_models[model_id_and_revision].model
  238. processor = self.loaded_models[model_id_and_revision].processor
  239. if progress_callback is not None:
  240. progress_callback({"status": "ready", "model": model_id_and_revision, "cached": True})
  241. return model, processor
  242. async def load_model_streaming(self, model_id_and_revision: str):
  243. """Load a model and stream progress as SSE events.
  244. Handles three cases:
  245. 1. Model already cached → single ``ready`` event
  246. 2. Load already in progress → join existing subscriber stream
  247. 3. First request → start loading, broadcast to all subscribers
  248. Args:
  249. model_id_and_revision (`str`): Model ID in ``'model_id@revision'`` format.
  250. Yields:
  251. `str`: SSE ``data: ...`` lines with progress updates.
  252. """
  253. mid = model_id_and_revision
  254. queue: asyncio.Queue[str | None] = asyncio.Queue()
  255. # Case 1: already cached
  256. if mid in self.loaded_models:
  257. self.loaded_models[mid].reset_timer()
  258. yield f"data: {json.dumps({'status': 'ready', 'model': mid, 'cached': True})}\n\n"
  259. return
  260. # Case 2: load in progress — join existing subscribers
  261. if mid in self._loading_tasks:
  262. self._loading_subscribers[mid].append(queue)
  263. while True:
  264. item = await queue.get()
  265. if item is None:
  266. break
  267. yield item
  268. return
  269. # Case 3: first request — start the load
  270. self._loading_subscribers[mid] = [queue]
  271. loop = asyncio.get_running_loop()
  272. def enqueue(payload: dict):
  273. msg = f"data: {json.dumps(payload)}\n\n"
  274. def broadcast():
  275. for q in self._loading_subscribers.get(mid, []):
  276. q.put_nowait(msg)
  277. loop.call_soon_threadsafe(broadcast)
  278. tqdm_class = make_progress_tqdm_class(enqueue, mid)
  279. def _tqdm_hook(factory, args, kwargs):
  280. return tqdm_class(*args, **kwargs)
  281. async def run_load():
  282. try:
  283. # Install a global tqdm hook so the "Loading weights" bar in
  284. # core_model_loading.py (which uses logging.tqdm) routes through
  285. # our ProgressTqdm. The tqdm_class kwarg only covers download bars.
  286. previous_hook = logging.set_tqdm_hook(_tqdm_hook)
  287. try:
  288. await asyncio.to_thread(
  289. self.load_model_and_processor,
  290. mid,
  291. progress_callback=enqueue,
  292. tqdm_class=tqdm_class,
  293. )
  294. finally:
  295. logging.set_tqdm_hook(previous_hook)
  296. except Exception as e:
  297. logger.error(f"Failed to load {mid}: {e}", exc_info=True)
  298. enqueue({"status": "error", "model": mid, "message": str(e)})
  299. finally:
  300. def _send_sentinel():
  301. for q in self._loading_subscribers.pop(mid, []):
  302. q.put_nowait(None)
  303. self._loading_tasks.pop(mid, None)
  304. loop.call_soon_threadsafe(_send_sentinel)
  305. self._loading_tasks[mid] = asyncio.create_task(run_load())
  306. while True:
  307. item = await queue.get()
  308. if item is None:
  309. break
  310. yield item
  311. def shutdown(self) -> None:
  312. """Delete all loaded models and free resources."""
  313. for timed in list(self.loaded_models.values()):
  314. timed.delete_model()
  315. @staticmethod
  316. def get_model_modality(
  317. model: "PreTrainedModel", processor: "ProcessorMixin | PreTrainedTokenizerFast | None" = None
  318. ) -> Modality:
  319. """Detect whether a model is an LLM or VLM based on its architecture.
  320. Args:
  321. model (`PreTrainedModel`): The loaded model.
  322. processor (`ProcessorMixin | PreTrainedTokenizerFast`, *optional*):
  323. If a plain tokenizer (not a multi-modal processor), short-circuits to LLM.
  324. Returns:
  325. `Modality`: The detected modality (``Modality.LLM`` or ``Modality.VLM``).
  326. """
  327. if processor is not None and isinstance(processor, PreTrainedTokenizerBase):
  328. return Modality.LLM
  329. from transformers.models.auto.modeling_auto import (
  330. MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
  331. MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES,
  332. )
  333. model_classname = model.__class__.__name__
  334. if model_classname in MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.values():
  335. return Modality.VLM
  336. elif model_classname in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
  337. return Modality.LLM
  338. else:
  339. raise ValueError(f"Unknown modality for: {model_classname}")
  340. @staticmethod
  341. @lru_cache
  342. def get_gen_models(cache_dir: str | None = None) -> list[dict]:
  343. """List generative models (LLMs and VLMs) available in the HuggingFace cache.
  344. Args:
  345. cache_dir (`str`, *optional*): Path to the HuggingFace cache directory.
  346. Defaults to the standard cache location.
  347. Returns:
  348. `list[dict]`: OpenAI-compatible model list entries with ``id``, ``object``, etc.
  349. """
  350. from transformers.models.auto.modeling_auto import (
  351. MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
  352. MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES,
  353. )
  354. generative_models = []
  355. logger.warning("Scanning the cache directory for LLMs and VLMs.")
  356. for repo in tqdm(scan_cache_dir(cache_dir).repos):
  357. if repo.repo_type != "model":
  358. continue
  359. for ref, revision_info in repo.refs.items():
  360. config_path = next((f.file_path for f in revision_info.files if f.file_name == "config.json"), None)
  361. if not config_path:
  362. continue
  363. config = json.loads(config_path.open().read())
  364. if not (isinstance(config, dict) and "architectures" in config):
  365. continue
  366. architectures = config["architectures"]
  367. llms = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values()
  368. vlms = MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.values()
  369. if any(arch for arch in architectures if arch in [*llms, *vlms]):
  370. author = repo.repo_id.split("/") if "/" in repo.repo_id else ""
  371. repo_handle = repo.repo_id + (f"@{ref}" if ref != "main" else "")
  372. generative_models.append(
  373. {
  374. "owned_by": author,
  375. "id": repo_handle,
  376. "object": "model",
  377. "created": repo.last_modified,
  378. }
  379. )
  380. return generative_models