generation.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390
  1. # Copyright (c) 2023, Albert Gu, Tri Dao.
  2. import gc
  3. import time
  4. from collections import namedtuple
  5. from dataclasses import dataclass, field
  6. from functools import partial
  7. from typing import Callable, Optional, Sequence, Union
  8. import torch
  9. import torch.nn.functional as F
  10. from einops import rearrange, repeat
  11. from torch import Tensor
  12. from torch.profiler import ProfilerActivity, profile, record_function
  13. from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput, TextStreamer
  14. @dataclass
  15. class InferenceParams:
  16. """Inference parameters that are passed to the main model in order
  17. to efficienly calculate and store the context during inference."""
  18. max_seqlen: int
  19. max_batch_size: int
  20. seqlen_offset: int = 0
  21. batch_size_offset: int = 0
  22. key_value_memory_dict: dict = field(default_factory=dict)
  23. lengths_per_sample: Optional[Tensor] = None
  24. def reset(self, max_seqlen, max_batch_size):
  25. self.max_seqlen = max_seqlen
  26. self.max_batch_size = max_batch_size
  27. self.seqlen_offset = 0
  28. if self.lengths_per_sample is not None:
  29. self.lengths_per_sample.zero_()
  30. def modify_logits_for_min_p_filtering(logits, min_p):
  31. """Set the logits for none min_p values to -inf. Done in-place."""
  32. if min_p <= 0.0 or min_p >= 1.0:
  33. return
  34. indices_to_remove = logits < min_p
  35. logits.masked_fill_(indices_to_remove, float("-Inf"))
  36. # https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
  37. # https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231
  38. def modify_logits_for_top_k_filtering(logits, top_k):
  39. """Set the logits for none top-k values to -inf. Done in-place."""
  40. indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
  41. logits.masked_fill_(indices_to_remove, float("-Inf"))
  42. # https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
  43. # https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
  44. def modify_logits_for_top_p_filtering(logits, top_p):
  45. """Set the logits for none top-p values to -inf. Done in-place."""
  46. if top_p <= 0.0 or top_p >= 1.0:
  47. return
  48. # First sort and calculate cumulative sum of probabilities.
  49. sorted_logits, sorted_indices = torch.sort(logits, descending=False)
  50. cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
  51. # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
  52. sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
  53. # scatter sorted tensors to original indexing
  54. indices_to_remove = sorted_indices_to_remove.scatter(
  55. 1, sorted_indices, sorted_indices_to_remove
  56. )
  57. logits.masked_fill_(indices_to_remove, float("-inf"))
  58. def modify_logit_for_repetition_penalty(logits, prev_output_tokens, repetition_penalty=1.0):
  59. """Apply repetition penalty. See https://arxiv.org/abs/1909.05858
  60. logits: (batch_size, vocab_size)
  61. prev_output_tokens: (batch_size, seq_len)
  62. """
  63. if repetition_penalty == 1.0:
  64. return logits
  65. score = torch.gather(logits, 1, prev_output_tokens)
  66. # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
  67. score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty)
  68. logits.scatter_(1, prev_output_tokens, score)
  69. return logits
  70. def sample(logits, top_k=1, top_p=0.0, min_p=0.0, temperature=1.0):
  71. """Sample from top-k logits.
  72. Arguments:
  73. logits: Tensor of shape (batch_size, vocab_size)
  74. """
  75. if top_k == 1: # Short-circuit for greedy decoding
  76. return logits.argmax(dim=-1)
  77. else:
  78. if top_p > 0.0:
  79. assert top_p <= 1.0, "top-p should be in (0, 1]."
  80. if top_k > 0:
  81. top_k = min(top_k, logits.size(-1)) # Safety check
  82. logits_top, indices = torch.topk(logits, top_k, dim=-1)
  83. if temperature != 1.0:
  84. logits_top /= temperature
  85. modify_logits_for_top_p_filtering(logits_top, top_p)
  86. return indices[
  87. torch.arange(indices.shape[0], device=indices.device),
  88. torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1),
  89. ]
  90. else:
  91. if min_p > 0.0:
  92. logits_top = logits.clone()
  93. max_prob = logits_top[..., 0].item()
  94. min_prob = max_prob * min_p
  95. modify_logits_for_min_p_filtering(logits_top, min_prob)
  96. if temperature != 1.0:
  97. logits_top /= temperature
  98. return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1)
  99. # Clone so that when we modify for top_p we don't change the original logits
  100. logits_top = logits / temperature if temperature != 1.0 else logits.clone()
  101. modify_logits_for_top_p_filtering(logits_top, top_p)
  102. return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(
  103. dim=-1
  104. )
  105. @torch.inference_mode()
  106. def decode(
  107. input_ids,
  108. model,
  109. max_length,
  110. top_k=1,
  111. top_p=0.0,
  112. min_p=0.0,
  113. temperature=1.0,
  114. repetition_penalty=1.0,
  115. eos_token_id=None,
  116. teacher_outputs=None,
  117. vocab_size=None,
  118. cg=False,
  119. enable_timing=False,
  120. output_scores=False,
  121. streamer: Optional[TextStreamer] = None
  122. ):
  123. """Decoding, either greedy or with top-k or top-p sampling.
  124. If top-k = 0, don't limit the number of candidates (pure sampling).
  125. Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
  126. then top-p.
  127. We assume that all sequences in the same batch have the same length.
  128. Arguments:
  129. input_ids: (batch, seq_len)
  130. max_length: int
  131. teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the
  132. logits, the next token is taken from the teacher_outputs. Useful for testing.
  133. Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
  134. sequences: (batch, max_length)
  135. scores: tuples of (batch, vocab_size)
  136. """
  137. if streamer is not None:
  138. streamer.put(input_ids.cpu())
  139. batch_size, seqlen_og = input_ids.shape
  140. teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0
  141. if cg:
  142. if not hasattr(model, "_decoding_cache"):
  143. model._decoding_cache = None
  144. model._decoding_cache = update_graph_cache(
  145. model,
  146. model._decoding_cache,
  147. batch_size,
  148. seqlen_og,
  149. max_length,
  150. )
  151. inference_params = model._decoding_cache.inference_params
  152. inference_params.reset(max_length, batch_size)
  153. else:
  154. inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)
  155. def get_logits(input_ids, inference_params):
  156. decoding = inference_params.seqlen_offset > 0
  157. if decoding:
  158. position_ids = torch.full(
  159. (batch_size, 1),
  160. inference_params.seqlen_offset,
  161. dtype=torch.long,
  162. device=input_ids.device,
  163. )
  164. else:
  165. position_ids = None
  166. if not cg or not decoding:
  167. logits = model(
  168. input_ids,
  169. position_ids=position_ids,
  170. inference_params=inference_params,
  171. num_last_tokens=1,
  172. ).logits.squeeze(dim=1)
  173. else:
  174. logits = model._decoding_cache.run(
  175. input_ids, position_ids, inference_params.seqlen_offset
  176. ).squeeze(dim=1)
  177. return logits[..., :vocab_size] if vocab_size is not None else logits
  178. def sample_tokens(logits, inference_params):
  179. if teacher_outputs is None or teacher_output_len <= inference_params.seqlen_offset:
  180. token = sample(logits, top_k=top_k, top_p=top_p, min_p=min_p, temperature=temperature)
  181. else:
  182. token = teacher_outputs[:, inference_params.seqlen_offset]
  183. # return rearrange(token, "b -> b 1")
  184. return token.unsqueeze(1)
  185. def should_stop(current_token, inference_params):
  186. if inference_params.seqlen_offset == 0:
  187. return False
  188. if eos_token_id is not None and (current_token == eos_token_id).all():
  189. return True
  190. if inference_params.seqlen_offset >= max_length - 1:
  191. return True
  192. return False
  193. start = torch.cuda.Event(enable_timing=enable_timing)
  194. end = torch.cuda.Event(enable_timing=enable_timing)
  195. if enable_timing:
  196. start.record()
  197. scores, sequences = [], [input_ids]
  198. sequences_cat = input_ids
  199. while not should_stop(sequences[-1], inference_params):
  200. logits = get_logits(sequences[-1], inference_params)
  201. if output_scores:
  202. scores.append(logits.clone())
  203. inference_params.seqlen_offset += sequences[-1].shape[1]
  204. if repetition_penalty == 1.0:
  205. sampled_tokens = sample_tokens(logits, inference_params)
  206. else:
  207. logits = modify_logit_for_repetition_penalty(
  208. logits, sequences_cat, repetition_penalty
  209. )
  210. sampled_tokens = sample_tokens(logits, inference_params)
  211. sequences_cat = torch.cat([sequences_cat, sampled_tokens], dim=1)
  212. sequences.append(sampled_tokens)
  213. if streamer is not None:
  214. streamer.put(sampled_tokens.cpu())
  215. if streamer is not None:
  216. streamer.end()
  217. if enable_timing:
  218. end.record()
  219. torch.cuda.synchronize()
  220. print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms")
  221. output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
  222. return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores))
  223. class GenerationMixin:
  224. def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
  225. raise NotImplementedError
  226. def generate(
  227. self,
  228. input_ids,
  229. max_length,
  230. top_k=1,
  231. top_p=0.0,
  232. min_p=0.0,
  233. temperature=1.0,
  234. return_dict_in_generate=False,
  235. output_scores=False,
  236. **kwargs,
  237. ):
  238. output = decode(
  239. input_ids, self, max_length, top_k=top_k, top_p=top_p, min_p = min_p, temperature=temperature, output_scores=output_scores, **kwargs
  240. )
  241. if not output_scores:
  242. output.scores = None
  243. return output if return_dict_in_generate else output.sequences
  244. @dataclass
  245. class DecodingCGCache:
  246. max_batch_size: int = 0
  247. max_seqlen: int = 0
  248. device = None
  249. dtype = None
  250. callables: dict = field(default_factory=dict)
  251. mempool = None
  252. inference_params: Optional[InferenceParams] = None
  253. run: Optional[Callable] = None
@torch.inference_mode()
def update_graph_cache(
    model,
    cache,
    batch_size,
    seqlen_og,
    max_seqlen,
    decoding_seqlens=(1,),
    dtype=None,
    n_warmups=2,
):
    """Create or refresh the cache of captured decoding callables for ``model``.

    Arguments:
        model: object providing ``parameters()`` and ``allocate_inference_cache``.
        cache: an existing DecodingCGCache, or None to start a new one.
        batch_size, max_seqlen: sizes requested for this run.
        seqlen_og: value used to fill ``lengths_per_sample`` and as the initial
            ``seqlen_offset`` when the inference parameters are rebuilt.
        decoding_seqlens: decoding lengths for which a callable must exist.
        dtype: dtype for the inference cache; defaults to the dtype of the first
            model parameter.
        n_warmups: forwarded to ``capture_graph``.
    Returns:
        The cache, with ``run`` set to a dispatcher over the captured callables
        and ``inference_params.seqlen_offset`` reset to 0.
    """
    if cache is None:
        cache = DecodingCGCache()
    # Device and default dtype are taken from the first model parameter.
    param_example = next(iter(model.parameters()))
    device = param_example.device
    if dtype is None:
        dtype = param_example.dtype
    if (
        (device, dtype) != (cache.device, cache.dtype)
        or batch_size > cache.max_batch_size
        or max_seqlen > cache.max_seqlen
    ):  # Invalidate the cache
        # Drop the old references first and collect garbage before new buffers
        # are allocated below.
        cache.callables = {}
        cache.mempool = None
        cache.inference_params = None
        gc.collect()
        cache.device, cache.dtype = device, dtype
        cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
        assert hasattr(model, "allocate_inference_cache"), "CUDA graph decoding requires that the model has a method allocate_inference_cache"
        inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype)
        lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
        cache.inference_params = InferenceParams(
            max_seqlen=max_seqlen,
            max_batch_size=batch_size,
            seqlen_offset=seqlen_og,
            key_value_memory_dict=inf_cache,
            lengths_per_sample=lengths_per_sample,
        )
        # One pool handle is passed to every capture_graph call made for this cache.
        cache.mempool = torch.cuda.graphs.graph_pool_handle()
    # Capture a callable for every (batch_size, decoding_seqlen) pair not cached yet.
    for decoding_seqlen in decoding_seqlens:
        if (batch_size, decoding_seqlen) not in cache.callables:
            cache.callables[batch_size, decoding_seqlen] = capture_graph(
                model,
                cache.inference_params,
                batch_size,
                max_seqlen,
                decoding_seqlen=decoding_seqlen,
                mempool=cache.mempool,
                n_warmups=n_warmups,
            )

    def dispatch(input_ids, position_ids, seqlen):
        # Select the callable from the first two dimensions of the actual input.
        batch_size, decoding_seqlen = input_ids.shape[:2]
        return cache.callables[batch_size, decoding_seqlen](input_ids, position_ids, seqlen)

    cache.run = dispatch
    cache.inference_params.seqlen_offset = 0  # Reset so it's not confusing
    return cache
def capture_graph(
    model, inference_params, batch_size, max_seqlen, decoding_seqlen=1, mempool=None, n_warmups=2
):
    """Capture one forward pass of ``model`` in a CUDA graph and return a replay function.

    Arguments:
        model: called as model(input_ids, position_ids=..., inference_params=...,
            num_last_tokens=decoding_seqlen); its result must expose ``.logits``.
        inference_params: shared parameters; ``lengths_per_sample`` is overwritten
            here, ``seqlen_offset`` is changed temporarily and restored before returning.
        batch_size, decoding_seqlen: shape of the static input buffers.
        max_seqlen: used to set the sequence offset during warmup and capture.
        mempool: passed as ``pool`` to ``torch.cuda.graph``.
        n_warmups: number of forward passes executed before the capture.
    Returns:
        ``run(new_input_ids, new_position_ids, seqlen)``, which copies its inputs
        into the static buffers, replays the graph and returns a clone of the logits.
    """
    device = next(iter(model.parameters())).device
    # Static buffers: the graph is captured on these tensors and ``run`` copies
    # fresh values into them before every replay.
    input_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
    position_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
    seqlen_offset_og = inference_params.seqlen_offset
    # Warm up and capture with the offset set to max_seqlen - decoding_seqlen.
    inference_params.seqlen_offset = max_seqlen - decoding_seqlen
    inference_params.lengths_per_sample[:] = inference_params.seqlen_offset

    # Warmup before capture
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(n_warmups):
            logits = model(
                input_ids,
                position_ids=position_ids,
                inference_params=inference_params,
                num_last_tokens=decoding_seqlen,
            ).logits
        s.synchronize()
        # This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,
        # which requires that graph launch and non-captured launch to not overlap (I think,
        # that's how I interpret the documentation). I'm not sure if this is required.
        if torch.distributed.is_initialized():
            torch.distributed.barrier()
    torch.cuda.current_stream().wait_stream(s)
    # Captures the graph
    # To allow capture, automatically sets a side stream as the current stream in the context
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, pool=mempool):
        logits = model(
            input_ids,
            position_ids=position_ids,
            inference_params=inference_params,
            num_last_tokens=decoding_seqlen,
        ).logits

    def run(new_input_ids, new_position_ids, seqlen):
        # Refresh the per-sample lengths and the static inputs, then replay.
        inference_params.lengths_per_sample[:] = seqlen
        input_ids.copy_(new_input_ids)
        position_ids.copy_(new_position_ids)
        graph.replay()
        # Clone because ``logits`` is the tensor produced during capture and is
        # reused by later replays.
        return logits.clone()

    # Restore the offset that was temporarily changed above.
    inference_params.seqlen_offset = seqlen_offset_og
    return run