executorch.py 48 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
  5. # the License. You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
  10. # an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
  11. # specific language governing permissions and limitations under the License.
  12. import logging
  13. import torch
  14. from ..cache_utils import (
  15. DynamicCache,
  16. DynamicLayer,
  17. DynamicSlidingWindowLayer,
  18. EncoderDecoderCache,
  19. StaticCache,
  20. StaticLayer,
  21. StaticSlidingWindowLayer,
  22. )
  23. from ..generation.configuration_utils import GenerationConfig
  24. from ..modeling_utils import PreTrainedModel
  25. from ..pytorch_utils import (
  26. is_torch_greater_or_equal,
  27. is_torch_greater_or_equal_than_2_6,
  28. )
  29. class TorchExportableModuleForVLM:
  30. """
  31. A wrapper class for exporting Vision-Language Models (VLMs) like SmolVLM2 for ExecuTorch.
  32. This class handles the export of three main components:
  33. 1. Vision encoder (processes images to visual features)
  34. 2. Connector/projector (maps visual features to text embedding space)
  35. 3. Text decoder (generates text from combined visual and text tokens)
  36. """
  37. def __init__(self, model, max_batch_size: int = 1, max_cache_len: int = 1024):
  38. """
  39. Initialize the exportable VLM module.
  40. Args:
  41. model: The VLM (e.g. SmolVLM) model instance
  42. max_batch_size: Maximum batch size. Always 1 for ExecuTorch
  43. max_cache_len: Maximum cache length for text generation
  44. """
  45. self.model = model
  46. self.max_batch_size = max_batch_size
  47. self.max_cache_len = max_cache_len
  48. self.config = model.config
  49. # Extract individual components
  50. self.vision_encoder = model.model.vision_model
  51. self.connector = model.model.connector
  52. self.text_decoder = model.model.text_model
  53. # Store exported programs
  54. self.exported_vision_encoder = None
  55. self.exported_connector = None
  56. self.exported_text_decoder = None
  57. def export_vision_encoder(self):
  58. """Export the vision encoder component."""
  59. self.vision_encoder.eval()
  60. # Create example input
  61. pixel_values = torch.randn(1, 3, 384, 384, dtype=torch.float32)
  62. # Define dynamic shapes
  63. dynamic_shapes = {
  64. "pixel_values": {
  65. 2: torch.export.Dim.AUTO,
  66. 3: torch.export.Dim.AUTO,
  67. }
  68. }
  69. self.exported_vision_encoder = torch.export.export(
  70. self.vision_encoder,
  71. args=(pixel_values,),
  72. dynamic_shapes=dynamic_shapes,
  73. strict=False,
  74. )
  75. return self.exported_vision_encoder
  76. def export_connector(self):
  77. """Export the connector component."""
  78. self.connector.eval()
  79. # Vision encoder output shape: [batch_size, num_patches, vision_hidden_size]
  80. vision_hidden_size = self.config.vision_config.hidden_size
  81. image_size = self.config.vision_config.image_size
  82. patch_size = self.config.vision_config.patch_size
  83. patches_per_dim = image_size // patch_size
  84. num_patches = patches_per_dim * patches_per_dim
  85. image_hidden_states = torch.randn(1, num_patches, vision_hidden_size, dtype=torch.float32)
  86. # Define dynamic shapes - static batch_size=1, dynamic num_patches
  87. dynamic_shapes = {"image_hidden_states": {1: torch.export.Dim.AUTO}}
  88. # Export the connector using torch.export
  89. self.exported_connector = torch.export.export(
  90. self.connector,
  91. args=(image_hidden_states,),
  92. dynamic_shapes=dynamic_shapes,
  93. strict=False,
  94. )
  95. return self.exported_connector
  96. def export_text_decoder(self):
  97. """Export the text decoder component."""
  98. # Create text decoder exportable wrapper
  99. self.exportable_text_decoder = TorchExportableModuleForDecoderOnlyLM(model=self.text_decoder)
  100. # Use the existing text decoder exportable wrapper
  101. seq_length = 3
  102. input_ids = torch.zeros((1, seq_length), dtype=torch.long)
  103. cache_position = torch.arange(seq_length, dtype=torch.long)
  104. max_seq_length = min(self.max_cache_len, self.config.text_config.max_position_embeddings)
  105. seq_len_dim = torch.export.Dim("seq_length_dim", max=max_seq_length - 1)
  106. dynamic_shapes = {
  107. "input_ids": {1: seq_len_dim},
  108. "cache_position": {0: seq_len_dim},
  109. }
  110. self.exported_text_decoder = self.exportable_text_decoder.export(
  111. input_ids=input_ids,
  112. cache_position=cache_position,
  113. dynamic_shapes=dynamic_shapes,
  114. strict=False,
  115. )
  116. return self.exported_text_decoder
  117. def export(self, **kwargs):
  118. """Export all components of the VLM model."""
  119. self.export_vision_encoder(**kwargs)
  120. self.export_connector(**kwargs)
  121. self.export_text_decoder(**kwargs)
  122. return {
  123. "vision_encoder": self.exported_vision_encoder,
  124. "connector": self.exported_connector,
  125. "text_decoder": self.exported_text_decoder,
  126. }
  127. def forward(self, pixel_values, input_ids, cache_position):
  128. """
  129. Simplified forward pass for inference with guaranteed non-null input_ids and cache_position.
  130. Args:
  131. pixel_values: Input images [1, channels, height, width] (optional)
  132. input_ids: Text token IDs [1, seq_len] (required - won't be None)
  133. cache_position: Cache positions [seq_len] (required - won't be None)
  134. Returns:
  135. Output with logits for text generation
  136. """
  137. def generate(
  138. self, pixel_values=None, input_ids=None, max_new_tokens=50, do_sample=False, temperature=1.0, **kwargs
  139. ):
  140. """
  141. Simplified generate method with guaranteed non-null input_ids.
  142. Args:
  143. pixel_values: Input images [1, channels, height, width] (optional)
  144. input_ids: Initial text tokens [1, seq_len] (required - won't be None)
  145. max_new_tokens: Maximum number of tokens to generate
  146. do_sample: Whether to use sampling or greedy decoding
  147. temperature: Temperature for sampling
  148. Returns:
  149. Generated sequences
  150. """
  151. class TorchExportableModuleForDecoderOnlyLM(torch.nn.Module):
  152. """
  153. A recipe module designed to make a `PreTrainedModel` exportable with `torch.export`,
  154. specifically for decoder-only LM with cache. This module ensures that the
  155. exported model is compatible with further lowering and execution in `ExecuTorch`.
  156. """
  157. def __init__(
  158. self,
  159. model: PreTrainedModel,
  160. batch_size: int | None = None,
  161. max_cache_len: int | None = None,
  162. device: torch.device | None = None,
  163. ) -> None:
  164. """
  165. Initializes the exportable module.
  166. Args:
  167. model (`PreTrainedModel`): The pretrained model to wrap.
  168. Raises:
  169. ValueError: If the model is configured with a unsupported cache implementation.
  170. """
  171. super().__init__()
  172. config = model.config.get_text_config()
  173. if not hasattr(config, "use_cache") or config.use_cache is False:
  174. raise ValueError("The model must have caching enabled to be performant.")
  175. if hasattr(config, "layer_types") and getattr(config, "sliding_window", None) is not None:
  176. self.model = TorchExportableModuleWithHybridCache(model, batch_size, max_cache_len, device)
  177. else:
  178. # If `layer_types` is not specified explicitly in the config or `sliding_window` is null,
  179. # there is only 1 type of layers, so export will use `StaticCache` by default.
  180. logging.info(
  181. "Using `StaticCache` for export as `layer_types` is not specified or `sliding_window` is `null` in the config."
  182. )
  183. self.model = TorchExportableModuleWithStaticCache(model, batch_size, max_cache_len, device)
  184. def forward(
  185. self,
  186. input_ids: torch.Tensor | None = None,
  187. inputs_embeds: torch.Tensor | None = None,
  188. cache_position: torch.Tensor | None = None,
  189. ) -> torch.Tensor:
  190. """
  191. Forward pass of the module, which is compatible with the ExecuTorch llm runner.
  192. Args:
  193. input_ids (`torch.Tensor`): Tensor representing current input token id to the module.
  194. inputs_embeds (`torch.Tensor`): Tensor representing current input embeddings to the module.
  195. cache_position (`torch.Tensor`): Tensor representing current input position in the cache.
  196. Returns:
  197. torch.Tensor: Logits output from the model.
  198. """
  199. return self.model.forward(input_ids=input_ids, inputs_embeds=inputs_embeds)
  200. def export(
  201. self,
  202. input_ids: torch.Tensor | None = None,
  203. inputs_embeds: torch.Tensor | None = None,
  204. cache_position: torch.Tensor | None = None,
  205. dynamic_shapes: dict | None = None,
  206. strict: bool | None = None,
  207. ) -> torch.export.ExportedProgram:
  208. """
  209. Export the wrapped module using `torch.export`.
  210. Args:
  211. input_ids (`Optional[torch.Tensor]`):
  212. Tensor representing current input token id to the module. Must specify either this or inputs_embeds.
  213. inputs_embeds (`Optional[torch.Tensor]`):
  214. Tensor representing current input embeddings to the module. Must specify either this or input_ids.
  215. cache_position (`Optional[torch.Tensor]`):
  216. Tensor representing current input position in the cache. If not provided, a default tensor will be used.
  217. dynamic_shapes (`Optional[dict]`):
  218. Dynamic shapes to use for export if specified.
  219. strict(`Optional[bool]`):
  220. Flag to instruct `torch.export` to use `torchdynamo`.
  221. Returns:
  222. torch.export.ExportedProgram: The exported program that can be used for inference.
  223. Examples:
  224. Export with input_ids:
  225. ```python
  226. # Prepare inputs
  227. input_ids = torch.tensor([[1, 2, 3]], dtype=torch.long, device=model.device)
  228. cache_position = torch.arange(input_ids.shape[-1], dtype=torch.long, device=model.device)
  229. # Export
  230. exported = exportable_module.export(
  231. input_ids=input_ids,
  232. cache_position=cache_position
  233. )
  234. ```
  235. Export with inputs_embeds:
  236. ```python
  237. # Prepare embeddings
  238. inputs_embeds = torch.randn(1, 3, 768, device=model.device) # batch_size=1, seq_len=3, hidden_size=768
  239. cache_position = torch.arange(inputs_embeds.shape[1], dtype=torch.long, device=model.device)
  240. # Export
  241. exported = exportable_module.export(
  242. inputs_embeds=inputs_embeds,
  243. cache_position=cache_position
  244. )
  245. ```
  246. """
  247. if not (input_ids is None) ^ (inputs_embeds is None):
  248. raise ValueError("Need to specify either input_ids or inputs_embeds.")
  249. if hasattr(self.model, "base_model_prefix"):
  250. base = getattr(self.model, self.model.base_model_prefix, self.model)
  251. model_device = base.device
  252. elif hasattr(self.model, "model"):
  253. model_device = self.model.model.device
  254. else:
  255. model_device = "cpu"
  256. logging.warning(
  257. "TorchExportableModuleForDecoderOnlyLM.export Can't infer device from the model. Set to CPU by default."
  258. )
  259. if input_ids is not None:
  260. input_kwargs = {
  261. "input_ids": input_ids,
  262. "cache_position": cache_position
  263. if cache_position is not None
  264. else torch.arange(input_ids.shape[-1], dtype=torch.long, device=model_device),
  265. }
  266. else: # inputs_embeds
  267. input_kwargs = {
  268. "inputs_embeds": inputs_embeds,
  269. "cache_position": cache_position
  270. if cache_position is not None
  271. else torch.arange(inputs_embeds.shape[1], dtype=torch.long, device=model_device),
  272. }
  273. exported_program = torch.export.export(
  274. self.model,
  275. args=(),
  276. kwargs=input_kwargs,
  277. dynamic_shapes=dynamic_shapes,
  278. strict=strict if strict is not None else True,
  279. )
  280. return exported_program
  281. @staticmethod
  282. def generate(
  283. exported_program: torch.export.ExportedProgram,
  284. tokenizer,
  285. prompt: str,
  286. max_new_tokens: int = 20,
  287. do_sample: bool = False,
  288. temperature: float = 1.0,
  289. top_k: int = 50,
  290. top_p: float = 1.0,
  291. device: str = "cpu",
  292. ) -> str:
  293. """
  294. Generate a sequence of tokens using an exported program.
  295. Args:
  296. exported_program (`torch.export.ExportedProgram`): The exported model being used for generate.
  297. tokenizer: The tokenizer to use.
  298. prompt (str): The input prompt.
  299. max_new_tokens (int): Maximum number of new tokens to generate.
  300. do_sample (bool): Whether to use sampling or greedy decoding.
  301. temperature (float): The temperature for sampling.
  302. top_k (int): The number of highest probability tokens to keep for top-k sampling.
  303. top_p (float): The cumulative probability for nucleus sampling.
  304. device (str): The device to use.
  305. Returns:
  306. str: The generated text.
  307. """
  308. # Get the module from the exported program
  309. exported_module = exported_program.module()
  310. # Tokenize the prompt
  311. input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
  312. # Initialize with the prompt
  313. generated_ids = input_ids.clone()
  314. # Process the prompt tokens first
  315. curr_position = 0
  316. for i in range(input_ids.shape[1]):
  317. # Process one token at a time
  318. curr_input_ids = input_ids[:, i : i + 1]
  319. curr_cache_position = torch.tensor([curr_position], dtype=torch.long, device=device)
  320. # Forward pass
  321. _ = exported_module(input_ids=curr_input_ids, cache_position=curr_cache_position)
  322. curr_position += 1
  323. # Generate new tokens
  324. for _ in range(max_new_tokens):
  325. # Get the last token as input
  326. curr_input_ids = generated_ids[:, -1:]
  327. curr_cache_position = torch.tensor([curr_position], dtype=torch.long, device=device)
  328. # Forward pass to get next token logits
  329. outputs = exported_module(input_ids=curr_input_ids, cache_position=curr_cache_position)
  330. # Get the next token ID
  331. if do_sample:
  332. # Apply temperature
  333. if temperature > 0:
  334. logits = outputs / temperature
  335. else:
  336. logits = outputs
  337. # Apply top-k filtering
  338. if top_k > 0:
  339. indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
  340. logits[indices_to_remove] = float("-inf")
  341. # Apply top-p (nucleus) filtering
  342. if top_p < 1.0:
  343. sorted_logits, sorted_indices = torch.sort(logits, descending=True)
  344. cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
  345. # Remove tokens with cumulative probability above the threshold
  346. sorted_indices_to_remove = cumulative_probs > top_p
  347. # Shift the indices to the right to keep also the first token above the threshold
  348. sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
  349. sorted_indices_to_remove[..., 0] = 0
  350. # Scatter sorted tensors to original indexing
  351. indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
  352. logits[indices_to_remove] = float("-inf")
  353. # Sample from the filtered distribution
  354. probs = torch.softmax(logits, dim=-1)
  355. next_token_id = torch.multinomial(probs, num_samples=1)
  356. else:
  357. # Greedy decoding
  358. next_token_id = outputs.argmax(dim=-1, keepdim=True)
  359. # Ensure next_token_id has the right shape before concatenation
  360. if next_token_id.dim() > 2:
  361. next_token_id = next_token_id.squeeze(-1)
  362. # Append to the generated sequence
  363. generated_ids = torch.cat([generated_ids, next_token_id], dim=-1)
  364. curr_position += 1
  365. # Stop if we generate an EOS token
  366. if next_token_id.item() == tokenizer.eos_token_id:
  367. break
  368. # Decode the generated text
  369. return tokenizer.decode(generated_ids[0], skip_special_tokens=True)
  370. def get_head_shapes(config) -> tuple[int | list[int], int | list[int]]:
  371. """Returns a tuple `(num_heads, head_dim)` containing either 2 ints, or a list of int with the value for each
  372. layer."""
  373. # Gemma4 has different head_dim and num_heads depending on layer type
  374. if hasattr(config, "global_head_dim"):
  375. head_dim = [
  376. config.global_head_dim if layer == "full_attention" else config.head_dim
  377. for layer in config.layer_types[: -config.num_kv_shared_layers]
  378. ]
  379. num_heads = [
  380. config.num_global_key_value_heads
  381. if layer == "full_attention" and config.attention_k_eq_v
  382. else config.num_key_value_heads
  383. for layer in config.layer_types[: -config.num_kv_shared_layers]
  384. ]
  385. else:
  386. head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  387. num_heads = getattr(config, "num_key_value_heads", config.num_attention_heads)
  388. return num_heads, head_dim
  389. class TorchExportableModuleWithStaticCache(torch.nn.Module):
  390. """
  391. A recipe module designed to make a `PreTrainedModel` exportable with `torch.export`,
  392. specifically for decoder-only LM to `StaticCache`. This module ensures that the
  393. exported model is compatible with further lowering and execution in `ExecuTorch`.
  394. Note:
  395. This class is specifically designed to support export process using `torch.export`
  396. in a way that ensures the model can be further lowered and run efficiently in `ExecuTorch`.
  397. """
  398. def __init__(
  399. self,
  400. model: PreTrainedModel,
  401. batch_size: int | None = None,
  402. max_cache_len: int | None = None,
  403. device: torch.device | None = None,
  404. ) -> None:
  405. """
  406. Initializes the wrapper module with the pretrained model.
  407. Args:
  408. model (`PreTrainedModel`): The pretrained model to wrap. The model must have caching
  409. enabled and use a 'static' caching implementation.
  410. batch_size (`Optional[int]`): The batch size of the model. If not provided, we check if a value can be found
  411. in `generation_config.cache_config` and otherwise we raise a ValueError.
  412. max_cache_len (`Optional[int]`): The maximum cache length for generation. Same mechanism as `batch_size` if
  413. not provided.
  414. device (`Optional[torch.device]`): The device to use. If not provided, we check if a value can be found
  415. in `generation_config.cache_config` and otherwise we use `model.device` (no error is raised).
  416. Raises:
  417. AssertionError: If the pretrained model does not have caching enabled or if it does
  418. not use a 'static' caching implementation in `model.generation_config`.
  419. ValueError: If `batch_size` or `max_cache_len` is not provided, either as an argument or in `cache_config`.
  420. """
  421. super().__init__()
  422. config = model.config.get_text_config()
  423. generation_config = model.generation_config
  424. # Sanity checks
  425. if generation_config is None:
  426. raise AssertionError(
  427. "The model must have a generation config to be exported with static caching. "
  428. "Please set `generation_config` in `model`."
  429. )
  430. if not generation_config.use_cache:
  431. raise AssertionError(
  432. "The model must have caching enabled to be exported with static caching. "
  433. "Please set `generation_config.use_cache=True`."
  434. )
  435. if generation_config.cache_implementation != "static":
  436. raise AssertionError(
  437. "The model must use a 'static' caching implementation to be exported with static caching. "
  438. "Please set `generation_config.cache_implementation='static'`."
  439. )
  440. cache_config = {} if generation_config.cache_config is None else generation_config.cache_config
  441. # Ensure batch_size and max_cache_len are set
  442. if batch_size is None:
  443. batch_size = cache_config.get("batch_size", None)
  444. if batch_size is None:
  445. raise ValueError("batch_size must be provided, either as an argument or in cache_config.")
  446. if max_cache_len is None:
  447. max_cache_len = cache_config.get("max_cache_len", None)
  448. if max_cache_len is None:
  449. raise ValueError("max_cache_len must be provided, either as an argument or in cache_config.")
  450. # Infer device if not provided
  451. if device is None:
  452. device = cache_config.get("device", model.device)
  453. # Initialize the static cache
  454. self.model = model
  455. self.static_cache = StaticCache(max_cache_len=max_cache_len, config=config)
  456. # Since StaticSlidingWindow have dynamic control flow that cannot be avoided, we have to replace them here by
  457. # simple StaticLayer... It means that any generation beyond the window is unfortunately unsupported
  458. for i, layer in enumerate(self.static_cache.layers):
  459. if isinstance(layer, StaticSlidingWindowLayer):
  460. self.static_cache.layers[i] = StaticLayer(max_cache_len)
  461. num_heads, head_dim = get_head_shapes(config)
  462. dtype = self.model.dtype
  463. # We need this call to initialize all the layers (otherwise it's done lazily, which is not exportable)
  464. self.static_cache.early_initialization(batch_size, num_heads, head_dim, dtype, device)
  465. # Register cache buffers to make them exportable
  466. for i, layer in enumerate(self.static_cache.layers):
  467. self.register_buffer(f"key_cache_{i}", layer.keys, persistent=False)
  468. self.register_buffer(f"value_cache_{i}", layer.values, persistent=False)
  469. self.register_buffer(f"cumulative_length_{i}", layer.cumulative_length, persistent=False)
  470. def forward(
  471. self,
  472. input_ids: torch.LongTensor | None = None,
  473. inputs_embeds: torch.Tensor | None = None,
  474. cache_position: torch.Tensor | None = None,
  475. ):
  476. """
  477. Forward pass of the module, which is compatible with the ExecuTorch runtime.
  478. Args:
  479. input_ids (`torch.Tensor`): Tensor representing current input token id to the module.
  480. inputs_embeds (`torch.Tensor`): Tensor representing current input embeddings to the module.
  481. cache_position (`torch.Tensor`): Tensor representing current input position in the cache.
  482. Returns:
  483. torch.Tensor: Logits output from the model.
  484. This forward adapter serves two primary purposes:
  485. 1. **Making the Model `torch.export`-Compatible**:
  486. The adapter hides unsupported objects, such as the `Cache`, from the graph inputs and outputs,
  487. enabling the model to be exportable using `torch.export` without encountering issues.
  488. 2. **Ensuring Compatibility with `ExecuTorch` runtime**:
  489. The adapter matches the model's forward signature with that in `executorch/extension/llm/runner`,
  490. ensuring that the exported model can be executed in `ExecuTorch` out-of-the-box.
  491. """
  492. # Start by resetting static cache (it's needed to be able to run several generations with the same exported program,
  493. # as otherwise it's mutated in-place indefinitely - we cannot call reset in-between the `generate` as the program was
  494. # already exported)
  495. for layer in self.static_cache.layers:
  496. layer.cumulative_length.copy_(cache_position[0:1])
  497. past_key_values = self.static_cache
  498. outs = self.model(
  499. input_ids=input_ids,
  500. inputs_embeds=inputs_embeds,
  501. attention_mask=None,
  502. past_key_values=past_key_values,
  503. use_cache=True,
  504. )
  505. if hasattr(outs, "logits"):
  506. # Returned outputs is `CausalLMOutputWithPast`
  507. return outs.logits
  508. else:
  509. # Returned the `last_hidden_state` from `BaseModelOutputWithPast`
  510. return outs.last_hidden_state
  511. @staticmethod
  512. def generate(
  513. exported_program: torch.export.ExportedProgram,
  514. prompt_token_ids: torch.Tensor,
  515. max_new_tokens: int,
  516. ) -> torch.Tensor:
  517. """
  518. Generate a sequence of tokens using an exported program.
  519. This util function is designed to test exported models by simulating the generation process.
  520. It processes the input prompt tokens sequentially (no parallel prefill).
  521. This generate function is not intended to replace the original `generate` method, and the support
  522. for leveraging the original `generate` is potentially planned!
  523. Args:
  524. exported_program (`torch.export.ExportedProgram`): The exported program generated via `torch.export`.
  525. prompt_token_ids (`torch.Tensor`): Tensor representing the input prompt token IDs.
  526. max_new_tokens (`int`): Maximum number of new tokens to generate. Note that the total generation
  527. length is limited by both `max_new_tokens` and the model's cache size.
  528. Returns:
  529. torch.Tensor: A tensor containing the generated sequence of token IDs, including the original prompt tokens.
  530. """
  531. device = prompt_token_ids.device
  532. prompt_token_len = prompt_token_ids.shape[-1]
  533. max_generation_length = prompt_token_len + max_new_tokens
  534. for buffer_name, buffer in exported_program.named_buffers():
  535. if buffer_name.startswith("key_cache"):
  536. max_cache_len = buffer.shape[2]
  537. max_generation_length = min(max_generation_length, max_cache_len)
  538. break
  539. response_tokens = []
  540. for input_pos in range(min(max_generation_length, prompt_token_len)):
  541. result = exported_program.module().forward(
  542. input_ids=prompt_token_ids[:, input_pos : input_pos + 1],
  543. cache_position=torch.tensor([input_pos], dtype=torch.long, device=device),
  544. )
  545. response_tokens.append(prompt_token_ids[0][input_pos].item())
  546. current_token = torch.argmax(result[:, -1, :], dim=-1).item()
  547. response_tokens.append(current_token)
  548. while len(response_tokens) < max_generation_length:
  549. result = exported_program.module().forward(
  550. input_ids=torch.tensor([[current_token]], dtype=torch.long, device=device),
  551. cache_position=torch.tensor([len(response_tokens)], dtype=torch.long, device=device),
  552. )
  553. current_token = torch.argmax(result[:, -1, :], dim=-1).item()
  554. response_tokens.append(current_token)
  555. return torch.tensor([response_tokens], dtype=torch.long, device=device)
  556. class TorchExportableModuleWithHybridCache(torch.nn.Module):
  557. """
  558. A recipe module designed to make a `PreTrainedModel` exportable with `torch.export`,
  559. specifically for decoder-only LM to hybrid `StaticCache`. This module ensures that the
  560. exported model is compatible with further lowering and execution in `ExecuTorch`.
  561. """
  562. def __init__(
  563. self,
  564. model: PreTrainedModel,
  565. batch_size: int | None = None,
  566. max_cache_len: int | None = None,
  567. device: torch.device | None = None,
  568. ) -> None:
  569. """
  570. Initializes the exportable module.
  571. Args:
  572. model (`PreTrainedModel`): The pretrained model to wrap.
  573. batch_size (`Optional[int]`): The batch size of the model. If not provided, we check if a value can be found
  574. in `generation_config.cache_config` and otherwise we raise a ValueError.
  575. max_cache_len (`Optional[int]`): The maximum cache length for generation. Same mechanism as `batch_size` if
  576. not provided.
  577. device (`Optional[torch.device]`): The device to use. If not provided, we check if a value can be found
  578. in `generation_config.cache_config` and otherwise we use `model.device` (no error is raised).
  579. Raises:
  580. AssertionError: If the model doesn't have the expected configuration for hybrid StaticCache.
  581. ValueError: If `batch_size` or `max_cache_len` is not provided, either as an argument or in `cache_config`.
  582. """
  583. super().__init__()
  584. self.model = model
  585. config = model.config.get_text_config()
  586. generation_config = model.generation_config
  587. # Sanity checks
  588. if generation_config is None:
  589. raise AssertionError(
  590. "The model must have a generation config to be exported with static caching. "
  591. "Please set `generation_config` in `model`."
  592. )
  593. if not config.use_cache:
  594. raise AssertionError("Model must have caching enabled.")
  595. cache_config = {} if generation_config.cache_config is None else generation_config.cache_config
  596. # Ensure batch_size and max_cache_len are set
  597. if batch_size is None:
  598. batch_size = cache_config.get("batch_size", None)
  599. if batch_size is None:
  600. raise ValueError("batch_size must be provided, either as an argument or in cache_config.")
  601. if max_cache_len is None:
  602. max_cache_len = cache_config.get("max_cache_len", None)
  603. if max_cache_len is None:
  604. raise ValueError("max_cache_len must be provided, either as an argument or in cache_config.")
  605. # Infer device if not provided
  606. if device is None:
  607. device = cache_config.get("device", model.device)
  608. # Initialize the cache
  609. self.cache = StaticCache(config=config, max_cache_len=max_cache_len)
  610. # Since StaticSlidingWindow have dynamic control flow that cannot be avoided, we have to replace them here by
  611. # simple StaticLayer... It means that any generation beyond the window is unfortunately unsupported
  612. for i, layer in enumerate(self.cache.layers):
  613. if isinstance(layer, StaticSlidingWindowLayer):
  614. self.cache.layers[i] = StaticLayer(max_cache_len)
  615. num_heads, head_dim = get_head_shapes(config)
  616. dtype = self.model.dtype
  617. # We need this call to initialize all the layers (otherwise it's done lazily, which is not exportable)
  618. self.cache.early_initialization(batch_size, num_heads, head_dim, dtype, device)
  619. # Register cache buffers to make them exportable
  620. for i, layer in enumerate(self.cache.layers):
  621. self.register_buffer(f"key_cache_{i}", layer.keys, persistent=False)
  622. self.register_buffer(f"value_cache_{i}", layer.values, persistent=False)
  623. self.register_buffer(f"cumulative_length_{i}", layer.cumulative_length, persistent=False)
  624. def forward(
  625. self,
  626. input_ids: torch.LongTensor | None = None,
  627. inputs_embeds: torch.Tensor | None = None,
  628. cache_position: torch.Tensor | None = None,
  629. ) -> torch.Tensor:
  630. """
  631. Forward pass of the module, which is compatible with the ExecuTorch llm runner.
  632. Args:
  633. input_ids (`torch.Tensor`): Tensor representing current input token id to the module.
  634. inputs_embeds (`Optional[torch.Tensor]`): Tensor representing current input embeddings to the module.
  635. cache_position (`torch.Tensor`): Tensor representing current input position in the cache.
  636. Returns:
  637. torch.Tensor: Logits output from the model.
  638. """
  639. # Start by resetting static cache (it's needed to be able to run several generations with the same exported program,
  640. # as otherwise it's mutated in-place indefinitely - we cannot call reset in-between the `generate` as the program was
  641. # already exported)
  642. for layer in self.cache.layers:
  643. layer.cumulative_length.copy_(cache_position[0:1])
  644. # Forward pass with the model
  645. outputs = self.model(
  646. input_ids=input_ids,
  647. inputs_embeds=inputs_embeds,
  648. attention_mask=None,
  649. past_key_values=self.cache,
  650. use_cache=True,
  651. )
  652. # Return only the logits to simplify the export
  653. return outputs.logits
  654. def convert_and_export_with_cache(
  655. model: PreTrainedModel,
  656. example_input_ids: torch.Tensor | None = None,
  657. example_cache_position: torch.Tensor | None = None,
  658. dynamic_shapes: dict | None = None,
  659. strict: bool | None = None,
  660. ):
  661. """
  662. Convert a `PreTrainedModel` into an exportable module and export it using `torch.export`,
  663. ensuring the exported model is compatible with `ExecuTorch`.
  664. Args:
  665. model (`PreTrainedModel`): The pretrained model to be exported.
  666. example_input_ids (`Optional[torch.Tensor]`): Example input token id used by `torch.export`.
  667. example_cache_position (`Optional[torch.Tensor]`): Example current cache position used by `torch.export`.
  668. dynamic_shapes(`Optional[dict]`): Dynamic shapes used by `torch.export`.
  669. strict(`Optional[bool]`): Flag to instruct `torch.export` to use `torchdynamo`.
  670. Returns:
  671. Exported program (`torch.export.ExportedProgram`): The exported program generated via `torch.export`.
  672. """
  673. import torch.export._trace
  674. with torch.no_grad():
  675. # TODO: The default inputs only work for text models. We need to add support for vision/audio models.
  676. example_input_ids = (
  677. example_input_ids
  678. if example_input_ids is not None
  679. else torch.tensor([[1]], dtype=torch.long, device=model.device)
  680. )
  681. example_cache_position = (
  682. example_cache_position
  683. if example_cache_position is not None
  684. else torch.tensor([0], dtype=torch.long, device=model.device)
  685. )
  686. if is_torch_greater_or_equal("2.6.0"):
  687. exported_program = torch.export.export(
  688. TorchExportableModuleWithStaticCache(model),
  689. args=(),
  690. kwargs={"input_ids": example_input_ids, "cache_position": example_cache_position},
  691. dynamic_shapes=dynamic_shapes,
  692. strict=strict if strict is not None else True,
  693. )
  694. else:
  695. if dynamic_shapes is not None:
  696. logging.warning(
  697. "Dynamic shapes spec will be ignored by convert_and_export_with_cache for torch < 2.6.0."
  698. )
  699. if strict is not None:
  700. logging.warning("The strict flag will be ignored by convert_and_export_with_cache for torch < 2.6.0.")
  701. # We have to keep this path for BC.
  702. #
  703. # Due to issue https://github.com/pytorch/pytorch/issues/128394, we need to switch to use an internal
  704. # export API and pre_dispatch=False. Switch to use the public API once the issue is included in 2.5 release.
  705. exported_program = torch.export._trace._export(
  706. TorchExportableModuleWithStaticCache(model),
  707. args=(),
  708. kwargs={"input_ids": example_input_ids, "cache_position": example_cache_position},
  709. pre_dispatch=False,
  710. strict=True,
  711. )
  712. return exported_program
  713. class Seq2SeqLMEncoderExportableModule(torch.nn.Module):
  714. """
  715. A wrapper module designed to make a Seq2Seq LM encoder exportable with `torch.export`.
  716. This module ensures that the exported encoder model is compatible with ExecuTorch.
  717. """
  718. def __init__(self, encoder_model):
  719. super().__init__()
  720. self.encoder = encoder_model
  721. def forward(self, input_ids):
  722. return self.encoder(input_ids=input_ids).last_hidden_state
  723. class Seq2SeqLMDecoderExportableModuleWithStaticCache(torch.nn.Module):
  724. """
  725. A wrapper module designed to make a Seq2Seq LM decoder exportable with `torch.export`,
  726. specifically for use with static caching. This module ensures the exported decoder
  727. is compatible with ExecuTorch.
  728. """
  729. def __init__(self, model, max_static_cache_length, batch_size):
  730. super().__init__()
  731. # Get the decoder component
  732. self.decoder = model.get_decoder()
  733. self.lm_head = model.lm_head
  734. self.config = model.config
  735. # Detect the device of the exported models by checking a parameter
  736. # We'll use the model's device as the target device
  737. model_device = next(model.parameters()).device
  738. # Initialize static cache for decoder and DynamicCache for encoder
  739. self.static_cache = StaticCache(config=self.config, max_cache_len=max_static_cache_length)
  740. # Since StaticSlidingWindow have dynamic control flow that cannot be avoided, we have to replace them here by
  741. # simple StaticLayer... It means that any generation beyond the window is unfortunately unsupported
  742. for i, layer in enumerate(self.static_cache.layers):
  743. if isinstance(layer, StaticSlidingWindowLayer):
  744. self.static_cache.layers[i] = StaticLayer(max_static_cache_length)
  745. num_heads, head_dim = get_head_shapes(self.config)
  746. self.static_cache.early_initialization(batch_size, num_heads, head_dim, torch.float32, model_device)
  747. self.cache = EncoderDecoderCache(self.static_cache, DynamicCache(config=self.config))
  748. register_dynamic_cache_export_support()
  749. # Register cache buffers to make them exportable
  750. for i, layer in enumerate(self.static_cache.layers):
  751. self.register_buffer(f"key_cache_{i}", layer.keys, persistent=False)
  752. self.register_buffer(f"value_cache_{i}", layer.values, persistent=False)
  753. self.register_buffer(f"cumulative_length_{i}", layer.cumulative_length, persistent=False)
  754. def forward(self, decoder_input_ids, encoder_hidden_states, cache_position):
  755. # Start by resetting static cache (it's needed to be able to run several generations with the same exported program,
  756. # as otherwise it's mutated in-place indefinitely - we cannot call reset in-between the `generate` as the program was
  757. # already exported)
  758. for layer in self.static_cache.layers:
  759. layer.cumulative_length.copy_(cache_position[0:1])
  760. # Get outputs from decoder
  761. outputs = self.decoder(
  762. input_ids=decoder_input_ids,
  763. encoder_hidden_states=encoder_hidden_states,
  764. past_key_values=self.cache,
  765. use_cache=True,
  766. )
  767. # Apply language model head
  768. lm_logits = self.lm_head(outputs[0])
  769. return lm_logits
  770. class Seq2SeqLMExportableModule(torch.nn.Module):
  771. def __init__(
  772. self, model, batch_size=1, max_hidden_seq_length=4096, cache_implementation="static", max_cache_length=1024
  773. ):
  774. super().__init__()
  775. self.full_model = model
  776. self.encoder = model.get_encoder()
  777. self.config = model.config
  778. self.max_hidden_seq_length = max_hidden_seq_length
  779. self.generation_config = GenerationConfig(
  780. use_cache=True,
  781. max_length=max_cache_length,
  782. cache_implementation=cache_implementation,
  783. cache_config={
  784. "batch_size": batch_size,
  785. "max_cache_len": max_cache_length,
  786. },
  787. eos_token_id=model.generation_config.eos_token_id,
  788. )
  789. self.exported_encoder = None
  790. self.exported_decoder = None
  791. def _export_encoder(self, encoder_input_ids):
  792. wrapped_encoder = Seq2SeqLMEncoderExportableModule(self.encoder).to(self.full_model.device).eval()
  793. # Define dynamic sequence length for encoder
  794. seq_len_dim = torch.export.Dim("encoder_seq_length", max=self.max_hidden_seq_length)
  795. # Export the encoder
  796. with torch.no_grad():
  797. exported_encoder = torch.export.export(
  798. wrapped_encoder, (encoder_input_ids,), dynamic_shapes={"input_ids": {1: seq_len_dim}}, strict=True
  799. )
  800. return exported_encoder
  801. def _export_decoder(self, decoder_input_ids, encoder_hidden_states, cache_position):
  802. target_device = self.full_model.device
  803. wrapped_decoder = (
  804. Seq2SeqLMDecoderExportableModuleWithStaticCache(
  805. model=self.full_model,
  806. max_static_cache_length=self.generation_config.cache_config.get("max_cache_len"),
  807. batch_size=self.generation_config.cache_config.get("batch_size"),
  808. )
  809. .to(target_device)
  810. .eval()
  811. )
  812. # Move input tensors to the same device as the wrapped decoder
  813. decoder_input_ids = decoder_input_ids.to(target_device)
  814. encoder_hidden_states = encoder_hidden_states.to(target_device)
  815. cache_position = cache_position.to(target_device)
  816. # Define dynamic dimension for encoder output sequence length
  817. encoder_seq_len_dim = torch.export.Dim("encoder_hidden_seq_length", max=self.max_hidden_seq_length)
  818. # Export the decoder
  819. with torch.no_grad():
  820. exported_decoder = torch.export.export(
  821. wrapped_decoder,
  822. (decoder_input_ids, encoder_hidden_states, cache_position),
  823. dynamic_shapes={
  824. "decoder_input_ids": None,
  825. "encoder_hidden_states": {1: encoder_seq_len_dim},
  826. "cache_position": None,
  827. },
  828. strict=True,
  829. )
  830. return exported_decoder
  831. def export(self, encoder_input_ids=None, decoder_input_ids=None, encoder_hidden_states=None, cache_position=None):
  832. device = self.full_model.device
  833. example_encoder_input_ids = (
  834. encoder_input_ids
  835. if encoder_input_ids is not None
  836. else torch.ones((1, 10), dtype=torch.long, device=device)
  837. )
  838. example_decoder_input_ids = (
  839. decoder_input_ids
  840. if decoder_input_ids is not None
  841. else torch.tensor([[0]], dtype=torch.long, device=device)
  842. ) # Start token
  843. example_cache_position = (
  844. cache_position if cache_position is not None else torch.tensor([0], dtype=torch.long, device=device)
  845. )
  846. example_encoder_hidden_states = (
  847. encoder_hidden_states
  848. if encoder_hidden_states is not None
  849. else torch.zeros(
  850. (self.generation_config.cache_config.get("batch_size"), 10, self.config.d_model),
  851. dtype=torch.float32,
  852. device=device,
  853. )
  854. )
  855. self.exported_encoder = self._export_encoder(example_encoder_input_ids)
  856. self.exported_decoder = self._export_decoder(
  857. example_decoder_input_ids, example_encoder_hidden_states, example_cache_position
  858. )
  859. # Return self to allow chaining
  860. return self
  861. def generate(self, prompt_token_ids, max_new_tokens):
  862. with torch.no_grad():
  863. model_device = self.full_model.device
  864. # Move input to the model's device if it's on a different device
  865. if prompt_token_ids.device != model_device:
  866. prompt_token_ids = prompt_token_ids.to(model_device)
  867. # Run encoder
  868. encoder_output = self.exported_encoder.module()(prompt_token_ids)
  869. # Initialize with start token (0 for T5) on the correct device
  870. decoder_input_ids = torch.tensor([[0]], dtype=torch.long, device=model_device)
  871. generated_ids = [0]
  872. # Generate tokens one by one
  873. for i in range(max_new_tokens - 1):
  874. # Run decoder for next token prediction
  875. logits = self.exported_decoder.module()(
  876. decoder_input_ids, encoder_output, torch.tensor([i], dtype=torch.long, device=model_device)
  877. )
  878. # Get next token
  879. next_token = torch.argmax(logits[:, -1, :], dim=-1).item()
  880. generated_ids.append(next_token)
  881. # Update input for next iteration on the correct device
  882. decoder_input_ids = torch.tensor([[next_token]], dtype=torch.long, device=model_device)
  883. # Check if EOS token
  884. if next_token == self.generation_config.eos_token_id:
  885. break
  886. return generated_ids
  887. def export_with_dynamic_cache(
  888. model: PreTrainedModel,
  889. example_input_ids: torch.Tensor | None = None,
  890. example_attention_mask: torch.Tensor | None = None,
  891. ):
  892. """
  893. Export a model with DynamicCache using `torch.export`, ensuring the exported model is compatible with `ExecuTorch`.
  894. Args:
  895. model (`PreTrainedModel`): The pretrained model to be exported.
  896. example_input_ids (`Optional[torch.Tensor]`): Example input token id used by `torch.export`.
  897. example_attention_mask (`Optional[torch.Tensor]`): Example attention mask used by `torch.export`.
  898. Returns:
  899. Exported program (`torch.export.ExportedProgram`): The exported program generated via `torch.export`.
  900. """
  901. register_dynamic_cache_export_support()
  902. with torch.no_grad():
  903. exported_program = torch.export.export(
  904. model,
  905. (),
  906. {
  907. "input_ids": example_input_ids,
  908. "attention_mask": example_attention_mask,
  909. "past_key_values": DynamicCache(config=model.config),
  910. "use_cache": True,
  911. },
  912. strict=False,
  913. )
  914. return exported_program
  915. def register_dynamic_cache_export_support():
  916. """
  917. Utilities for `DynamicCache` <> torch.export support
  918. """
  919. try:
  920. torch.utils._pytree.register_pytree_node(
  921. DynamicCache,
  922. lambda dynamic_cache: torch.utils._pytree._dict_flatten(_get_cache_dict(dynamic_cache)),
  923. _unflatten_dynamic_cache,
  924. serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}",
  925. flatten_with_keys_fn=lambda dynamic_cache: torch.utils._pytree._dict_flatten_with_keys(
  926. _get_cache_dict(dynamic_cache)
  927. ),
  928. )
  929. # TODO (tmanlaibaatar) This won't be needed in torch 2.7.
  930. torch.fx._pytree.register_pytree_flatten_spec(
  931. DynamicCache,
  932. lambda cache, spec: torch.fx._pytree._dict_flatten_spec(_get_cache_dict(cache), spec),
  933. )
  934. # Catching this in case there are multiple runs for some test runs
  935. except ValueError as e:
  936. if "already registered as pytree node" not in str(e):
  937. raise
  938. def _get_cache_dict(cache: DynamicCache):
  939. """Convert cache to dictionary format for pytree operations."""
  940. if any(not isinstance(layer, (DynamicLayer, DynamicSlidingWindowLayer)) for layer in cache.layers):
  941. raise RuntimeError("This pytree flattening function should only be applied to DynamicCache")
  942. if not is_torch_greater_or_equal_than_2_6:
  943. logging.warning("DynamicCache + torch.export is tested on torch 2.6.0+ and may not work on earlier versions.")
  944. return {
  945. "key_cache": [layer.keys for layer in cache.layers if layer.keys is not None],
  946. "value_cache": [layer.values for layer in cache.layers if layer.values is not None],
  947. }
  948. def _unflatten_dynamic_cache(values, context: torch.utils._pytree.Context):
  949. dictionary = torch.utils._pytree._dict_unflatten(values, context)
  950. cache = DynamicCache()
  951. # Reconstruct layers from keys and values lists
  952. key_list = dictionary.get("key_cache", [])
  953. value_list = dictionary.get("value_cache", [])
  954. for idx in range(max(len(key_list), len(value_list))):
  955. key = key_list[idx] if idx < len(key_list) else None
  956. value = value_list[idx] if idx < len(value_list) else None
  957. cache.update(key, value, idx)
  958. return cache