| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390 |
- # Copyright (c) 2023, Albert Gu, Tri Dao.
- import gc
- import time
- from collections import namedtuple
- from dataclasses import dataclass, field
- from functools import partial
- from typing import Callable, Optional, Sequence, Union
- import torch
- import torch.nn.functional as F
- from einops import rearrange, repeat
- from torch import Tensor
- from torch.profiler import ProfilerActivity, profile, record_function
- from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput, TextStreamer
@dataclass
class InferenceParams:
    """Mutable decoding state handed to the model on every forward call.

    It lets the model efficiently compute and store its context (e.g. cached
    per-layer state) across successive inference steps.
    """

    # Capacity limits the cached state is sized for.
    max_seqlen: int
    max_batch_size: int
    # Number of sequence positions already processed.
    seqlen_offset: int = 0
    batch_size_offset: int = 0
    # Storage for cached state; a fresh dict is created per instance.
    key_value_memory_dict: dict = field(default_factory=dict)
    # Optional per-sample length tensor; zeroed in place by ``reset``.
    lengths_per_sample: Optional[Tensor] = None

    def reset(self, max_seqlen, max_batch_size):
        """Rewind to the start of a new generation with the given limits.

        ``seqlen_offset`` goes back to 0 and ``lengths_per_sample`` (when set)
        is zeroed in place. ``batch_size_offset`` and
        ``key_value_memory_dict`` are left untouched.
        """
        self.seqlen_offset = 0
        self.max_seqlen, self.max_batch_size = max_seqlen, max_batch_size
        per_sample = self.lengths_per_sample
        if per_sample is not None:
            per_sample.zero_()
def modify_logits_for_min_p_filtering(logits, min_p):
    """Mask, in place, every entry of ``logits`` that is below ``min_p``.

    Entries strictly smaller than ``min_p`` are set to -inf. The call is a
    no-op unless ``0 < min_p < 1``.

    NOTE(review): the comparison is made against the raw values stored in
    ``logits``; no softmax is applied here, so the caller is responsible for
    passing a threshold expressed in the same space as ``logits``.
    """
    if min_p >= 1.0 or min_p <= 0.0:
        return None
    below_threshold = torch.lt(logits, min_p)
    logits.masked_fill_(below_threshold, float("-Inf"))
    return None
- # https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
- # https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231
def modify_logits_for_top_k_filtering(logits, top_k):
    """Keep only the ``top_k`` largest logits along the last dim, in place.

    Every entry strictly smaller than the k-th largest value of its row is
    overwritten with -inf (entries tied with the k-th value are kept).
    """
    # Smallest of the k best values per row, kept with a trailing dim of 1
    # so it broadcasts against ``logits``.
    kth_best = torch.topk(logits, top_k).values[..., -1:]
    outside_top_k = logits < kth_best
    logits.masked_fill_(outside_top_k, float("-Inf"))
- # https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
- # https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
def modify_logits_for_top_p_filtering(logits, top_p):
    """Apply nucleus (top-p) filtering to 2-D ``logits``, in place.

    The least likely tokens whose combined probability mass does not exceed
    ``1 - top_p`` are set to -inf. The call is a no-op unless
    ``0 < top_p < 1``.
    """
    if top_p >= 1.0 or top_p <= 0.0:
        return
    # Sort ascending so the running sum accumulates the low-probability tail.
    ascending_logits, ascending_idx = torch.sort(logits, descending=False)
    tail_mass = torch.cumsum(torch.softmax(ascending_logits, dim=-1), dim=-1)
    # Anything still inside the discarded tail is dropped; the most likely
    # token always survives because its cumulative mass is 1.
    drop_sorted = tail_mass <= (1 - top_p)
    # Map the mask from sorted order back to vocabulary order.
    drop = drop_sorted.scatter(1, ascending_idx, drop_sorted)
    logits.masked_fill_(drop, float("-inf"))
def modify_logit_for_repetition_penalty(logits, prev_output_tokens, repetition_penalty=1.0):
    """Penalize tokens that were already produced (https://arxiv.org/abs/1909.05858).

    Arguments:
        logits: (batch_size, vocab_size), modified in place.
        prev_output_tokens: (batch_size, seq_len) token ids seen so far.
        repetition_penalty: 1.0 disables the penalty.
    Returns:
        The same ``logits`` tensor that was passed in.
    """
    if repetition_penalty == 1.0:
        return logits
    seen = logits.gather(1, prev_output_tokens)
    # Negative logits are multiplied and non-negative ones divided, so with a
    # penalty above 1 a previously seen token always becomes less likely.
    penalized = torch.where(seen < 0, seen * repetition_penalty, seen / repetition_penalty)
    return logits.scatter_(1, prev_output_tokens, penalized)
def sample(logits, top_k=1, top_p=0.0, min_p=0.0, temperature=1.0):
    """Draw one next token per batch row from ``logits``.

    Strategy, in order of precedence:
      * ``top_k == 1``: greedy decoding (argmax); all other options are ignored.
      * ``top_k > 1``: keep the ``top_k`` best logits, apply temperature, then
        top-p filtering, and sample.
      * ``top_k <= 0`` and ``min_p > 0``: min-p sampling. Tokens whose
        probability is below ``min_p`` times the probability of the most
        likely token of the same row are discarded, then temperature is
        applied and a token is sampled. ``top_p`` is not used in this branch.
      * otherwise: apply temperature, then top-p filtering over the whole
        vocabulary, and sample.

    Arguments:
        logits: Tensor of shape (batch_size, vocab_size). Never modified.
        top_k: number of candidates to keep; 1 means greedy, <= 0 means no limit.
        top_p: nucleus threshold in (0, 1]; 0 disables it.
        min_p: relative probability floor; 0 disables it, values above 1 are
            treated as 1.
        temperature: softmax temperature; 1.0 leaves the logits unchanged.
    Returns:
        Tensor of shape (batch_size,) holding the chosen token ids.
    Raises:
        AssertionError: if ``top_p`` is greater than 1.
    """
    if top_k == 1:  # Short-circuit for greedy decoding
        return logits.argmax(dim=-1)
    if top_p > 0.0:
        assert top_p <= 1.0, "top-p should be in (0, 1]."
    if top_k > 0:
        top_k = min(top_k, logits.size(-1))  # Safety check
        # topk returns fresh tensors, so the in-place ops below do not touch
        # the caller's logits.
        logits_top, indices = torch.topk(logits, top_k, dim=-1)
        if temperature != 1.0:
            logits_top /= temperature
        modify_logits_for_top_p_filtering(logits_top, top_p)
        # The draw is a position inside the top-k slice; map it back to a
        # vocabulary id through ``indices``.
        return indices[
            torch.arange(indices.shape[0], device=indices.device),
            torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1),
        ]
    if min_p > 0.0:
        # Clone so the caller's logits stay untouched.
        logits_top = logits.clone()
        # The floor is defined on probabilities and is relative to each row's
        # most likely token, so it is computed per row. Capping the ratio at 1
        # guarantees that the most likely token is never filtered out.
        probs = torch.softmax(logits_top, dim=-1)
        floor = min(min_p, 1.0) * probs.max(dim=-1, keepdim=True).values
        logits_top.masked_fill_(probs < floor, float("-inf"))
        if temperature != 1.0:
            logits_top /= temperature
        return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1)
    # Clone so that when we modify for top_p we don't change the original logits
    logits_top = logits / temperature if temperature != 1.0 else logits.clone()
    modify_logits_for_top_p_filtering(logits_top, top_p)
    return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(
        dim=-1
    )
@torch.inference_mode()
def decode(
    input_ids,
    model,
    max_length,
    top_k=1,
    top_p=0.0,
    min_p=0.0,
    temperature=1.0,
    repetition_penalty=1.0,
    eos_token_id=None,
    teacher_outputs=None,
    vocab_size=None,
    cg=False,
    enable_timing=False,
    output_scores=False,
    streamer: Optional[TextStreamer] = None
):
    """Decoding, either greedy or with top-k, top-p or min-p sampling.

    If top-k = 0, don't limit the number of candidates (pure sampling).
    Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
    then top-p.
    We assume that all sequences in the same batch have the same length.

    Arguments:
        input_ids: (batch, seq_len)
        model: callable invoked as ``model(input_ids, position_ids=...,
            inference_params=..., num_last_tokens=1)``; its result must expose ``.logits``.
        max_length: int, stored as ``max_seqlen`` of the inference params and used by the
            stopping rule.
        top_k, top_p, min_p, temperature: forwarded unchanged to ``sample``.
        repetition_penalty: when not 1.0, the logits are passed through
            ``modify_logit_for_repetition_penalty`` before sampling.
        eos_token_id (optional): decoding stops once every row of the batch has just
            produced this token.
        teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the
            logits, the next token is taken from the teacher_outputs. Useful for testing.
        vocab_size (optional): when given, the logits are truncated to their first
            ``vocab_size`` entries along the last dimension.
        cg: when true, decoding steps after the prompt go through the cache kept in
            ``model._decoding_cache`` (built by ``update_graph_cache``).
        enable_timing: when true, the elapsed time is measured with CUDA events and printed.
        output_scores: when true, the logits of every step are collected and returned.
        streamer (optional): receives the prompt, then every sampled token, then ``end()``.
    Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
        sequences: (batch, max_length)
        scores: tuples of (batch, vocab_size); empty when ``output_scores`` is false
    """
    if streamer is not None:
        streamer.put(input_ids.cpu())

    batch_size, seqlen_og = input_ids.shape
    teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0
    if cg:
        # Create or refresh the cache stored on the model, then reuse its
        # inference params.
        if not hasattr(model, "_decoding_cache"):
            model._decoding_cache = None
        model._decoding_cache = update_graph_cache(
            model,
            model._decoding_cache,
            batch_size,
            seqlen_og,
            max_length,
        )
        inference_params = model._decoding_cache.inference_params
        inference_params.reset(max_length, batch_size)
    else:
        inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)

    def get_logits(input_ids, inference_params):
        # seqlen_offset == 0 marks the prompt pass; anything later is a
        # decoding step.
        decoding = inference_params.seqlen_offset > 0
        if decoding:
            # Every row is at the same position, namely the current offset.
            position_ids = torch.full(
                (batch_size, 1),
                inference_params.seqlen_offset,
                dtype=torch.long,
                device=input_ids.device,
            )
        else:
            position_ids = None
        if not cg or not decoding:
            # Regular forward pass: always for the prompt, and for every step
            # when cg is off.
            logits = model(
                input_ids,
                position_ids=position_ids,
                inference_params=inference_params,
                num_last_tokens=1,
            ).logits.squeeze(dim=1)
        else:
            # Decoding step with cg on: use the cached callable.
            logits = model._decoding_cache.run(
                input_ids, position_ids, inference_params.seqlen_offset
            ).squeeze(dim=1)
        return logits[..., :vocab_size] if vocab_size is not None else logits

    def sample_tokens(logits, inference_params):
        # Teacher forcing applies only while the teacher sequence still has a
        # token at the current offset.
        if teacher_outputs is None or teacher_output_len <= inference_params.seqlen_offset:
            token = sample(logits, top_k=top_k, top_p=top_p, min_p=min_p, temperature=temperature)
        else:
            token = teacher_outputs[:, inference_params.seqlen_offset]
        # Insert a dimension at index 1 so the result can be concatenated
        # with the other entries of `sequences` along dim 1.
        return token.unsqueeze(1)

    def should_stop(current_token, inference_params):
        # Never stop before the prompt has been processed, so at least one
        # token is always generated.
        if inference_params.seqlen_offset == 0:
            return False
        if eos_token_id is not None and (current_token == eos_token_id).all():
            return True
        if inference_params.seqlen_offset >= max_length - 1:
            return True
        return False

    # NOTE(review): both events are constructed even when enable_timing is
    # false; they are only recorded and read when it is true.
    start = torch.cuda.Event(enable_timing=enable_timing)
    end = torch.cuda.Event(enable_timing=enable_timing)

    if enable_timing:
        start.record()
    # `sequences` holds the prompt followed by one (batch, 1) tensor per step;
    # `sequences_cat` is the running concatenation fed to the repetition penalty.
    scores, sequences = [], [input_ids]
    sequences_cat = input_ids
    while not should_stop(sequences[-1], inference_params):
        logits = get_logits(sequences[-1], inference_params)
        if output_scores:
            # Clone because the repetition penalty below edits logits in place.
            scores.append(logits.clone())
        # Advance by the number of positions just fed to the model.
        inference_params.seqlen_offset += sequences[-1].shape[1]
        if repetition_penalty == 1.0:
            sampled_tokens = sample_tokens(logits, inference_params)
        else:
            logits = modify_logit_for_repetition_penalty(
                logits, sequences_cat, repetition_penalty
            )
            sampled_tokens = sample_tokens(logits, inference_params)
            sequences_cat = torch.cat([sequences_cat, sampled_tokens], dim=1)
        sequences.append(sampled_tokens)
        if streamer is not None:
            streamer.put(sampled_tokens.cpu())
    if streamer is not None:
        streamer.end()
    if enable_timing:
        end.record()
        torch.cuda.synchronize()
        print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms")
    output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
    return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores))
class GenerationMixin:
    """Mixin that adds a ``generate`` method built on top of ``decode``."""

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Hook to be overridden by the host class; the default always raises."""
        raise NotImplementedError

    def generate(
        self,
        input_ids,
        max_length,
        top_k=1,
        top_p=0.0,
        min_p=0.0,
        temperature=1.0,
        return_dict_in_generate=False,
        output_scores=False,
        **kwargs,
    ):
        """Run ``decode`` with this object as the model.

        Extra keyword arguments are forwarded to ``decode`` unchanged.
        Returns the full output object when ``return_dict_in_generate`` is
        true, otherwise only its ``sequences``. The ``scores`` attribute of
        the output is set to None unless ``output_scores`` is true.
        """
        decode_options = {
            "top_k": top_k,
            "top_p": top_p,
            "min_p": min_p,
            "temperature": temperature,
            "output_scores": output_scores,
        }
        result = decode(input_ids, self, max_length, **decode_options, **kwargs)
        if not output_scores:
            result.scores = None
        if return_dict_in_generate:
            return result
        return result.sequences
@dataclass
class DecodingCGCache:
    """State kept between calls to ``update_graph_cache``."""

    # Plain class-level defaults. They carry no annotation, so they are not
    # dataclass fields and do not appear in __init__; ``update_graph_cache``
    # assigns them on the instance.
    device = None
    dtype = None
    mempool = None

    # Largest batch size / sequence length the cached state was built for.
    max_batch_size: int = 0
    max_seqlen: int = 0
    # Captured callables keyed by (batch_size, decoding_seqlen).
    callables: dict = field(default_factory=dict)
    # Inference params shared by the captured callables.
    inference_params: Optional[InferenceParams] = None
    # Entry point that picks the callable matching the input shape.
    run: Optional[Callable] = None
@torch.inference_mode()
def update_graph_cache(
    model,
    cache,
    batch_size,
    seqlen_og,
    max_seqlen,
    decoding_seqlens=(1,),
    dtype=None,
    n_warmups=2,
):
    """Create or refresh the decoding cache used for CUDA-graph decoding.

    A new ``DecodingCGCache`` is created when ``cache`` is None. The cached
    state is rebuilt from scratch when the model's device/dtype changed or
    when the requested ``batch_size`` / ``max_seqlen`` exceed what the cache
    was built for. A callable is then captured for every
    ``(batch_size, decoding_seqlen)`` pair that is not cached yet.

    ``dtype`` defaults to the dtype of the model's first parameter.
    Returns the (possibly new) cache, with ``cache.run`` set to a dispatcher
    and ``cache.inference_params.seqlen_offset`` reset to 0.
    """
    if cache is None:
        cache = DecodingCGCache()
    first_param = next(iter(model.parameters()))
    device = first_param.device
    if dtype is None:
        dtype = first_param.dtype

    same_placement = (device, dtype) == (cache.device, cache.dtype)
    within_capacity = batch_size <= cache.max_batch_size and max_seqlen <= cache.max_seqlen
    if not (same_placement and within_capacity):
        # Invalidate the cache: drop every reference first and collect, then
        # allocate the replacement state.
        cache.callables = {}
        cache.mempool = None
        cache.inference_params = None
        gc.collect()
        cache.device = device
        cache.dtype = dtype
        cache.max_batch_size = batch_size
        cache.max_seqlen = max_seqlen
        assert hasattr(model, "allocate_inference_cache"), "CUDA graph decoding requires that the model has a method allocate_inference_cache"
        layer_state = model.allocate_inference_cache(batch_size, max_seqlen, dtype)
        sample_lengths = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
        cache.inference_params = InferenceParams(
            max_seqlen=max_seqlen,
            max_batch_size=batch_size,
            seqlen_offset=seqlen_og,
            key_value_memory_dict=layer_state,
            lengths_per_sample=sample_lengths,
        )
        cache.mempool = torch.cuda.graphs.graph_pool_handle()

    for step_len in decoding_seqlens:
        graph_key = (batch_size, step_len)
        if graph_key in cache.callables:
            continue
        cache.callables[graph_key] = capture_graph(
            model,
            cache.inference_params,
            batch_size,
            max_seqlen,
            decoding_seqlen=step_len,
            mempool=cache.mempool,
            n_warmups=n_warmups,
        )

    def dispatch(input_ids, position_ids, seqlen):
        # Select the captured callable from the first two dims of the input.
        rows, step = input_ids.shape[:2]
        captured = cache.callables[rows, step]
        return captured(input_ids, position_ids, seqlen)

    cache.run = dispatch
    cache.inference_params.seqlen_offset = 0  # Reset so it's not confusing
    return cache
def capture_graph(
    model, inference_params, batch_size, max_seqlen, decoding_seqlen=1, mempool=None, n_warmups=2
):
    """Capture one model forward pass as a CUDA graph and return a replay function.

    Arguments:
        model: callable invoked as ``model(input_ids, position_ids=...,
            inference_params=..., num_last_tokens=decoding_seqlen)``; its result must
            expose ``.logits``.
        inference_params: state passed to the model. ``lengths_per_sample`` must be set,
            since it is written in place. ``seqlen_offset`` is changed temporarily and
            restored before returning.
        batch_size, decoding_seqlen: shape of the input buffers used for the capture.
        max_seqlen: the capture runs with ``seqlen_offset = max_seqlen - decoding_seqlen``.
        mempool: passed as ``pool`` to ``torch.cuda.graph``.
        n_warmups: number of forward passes run before the capture.
    Returns:
        ``run(new_input_ids, new_position_ids, seqlen)``, which copies its inputs into the
        captured buffers, replays the graph and returns a clone of the captured logits.
    """
    device = next(iter(model.parameters())).device
    # Fixed input buffers: the captured graph reads from these, and `run`
    # refreshes them with copy_ before each replay.
    input_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
    position_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
    seqlen_offset_og = inference_params.seqlen_offset
    # Warm up and capture at the last offset that still leaves room for
    # decoding_seqlen tokens.
    inference_params.seqlen_offset = max_seqlen - decoding_seqlen
    inference_params.lengths_per_sample[:] = inference_params.seqlen_offset

    # Warmup before capture
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(n_warmups):
            logits = model(
                input_ids,
                position_ids=position_ids,
                inference_params=inference_params,
                num_last_tokens=decoding_seqlen,
            ).logits
        s.synchronize()
        # This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,
        # which requires that graph launch and non-captured launch to not overlap (I think,
        # that's how I interpret the documentation). I'm not sure if this is required.
        if torch.distributed.is_initialized():
            torch.distributed.barrier()
    torch.cuda.current_stream().wait_stream(s)
    # Captures the graph
    # To allow capture, automatically sets a side stream as the current stream in the context
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, pool=mempool):
        logits = model(
            input_ids,
            position_ids=position_ids,
            inference_params=inference_params,
            num_last_tokens=decoding_seqlen,
        ).logits

    def run(new_input_ids, new_position_ids, seqlen):
        # Update the per-sample lengths and the input buffers in place, then
        # replay the captured graph.
        inference_params.lengths_per_sample[:] = seqlen
        input_ids.copy_(new_input_ids)
        position_ids.copy_(new_position_ids)
        graph.replay()
        # Return a copy rather than the captured output tensor itself
        # (presumably because the next replay overwrites it).
        return logits.clone()

    # Restore the offset that was temporarily changed for warmup and capture.
    inference_params.seqlen_offset = seqlen_offset_og
    return run
|