generation_csm.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487
  1. # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from dataclasses import dataclass
  15. from typing import TYPE_CHECKING, Any, Optional
  16. import torch
  17. import torch.nn as nn
  18. from ...generation import (
  19. GenerateDecoderOnlyOutput,
  20. GenerationConfig,
  21. GenerationMixin,
  22. GenerationMode,
  23. )
  24. from ...generation.logits_process import LogitsProcessorList
  25. from ...generation.stopping_criteria import MaxLengthCriteria, StoppingCriteriaList
  26. from ...generation.utils import GenerateNonBeamOutput
  27. from ...utils import logging
  28. if TYPE_CHECKING:
  29. from ...generation.streamers import BaseStreamer
  30. logger = logging.get_logger(__name__)
  31. @dataclass
  32. class CsmGenerateOutput(GenerateDecoderOnlyOutput):
  33. """
  34. Outputs of CsmForConditionalGeneration.generate.
  35. Args:
  36. sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  37. The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
  38. if all batches finished early due to the `eos_token_id`.
  39. scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True`):
  40. Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
  41. at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
  42. each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
  43. logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`):
  44. Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
  45. at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
  46. each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
  47. attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
  48. Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
  49. `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
  50. hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`):
  51. Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
  52. `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
  53. past_key_values (`Cache`, *optional*, returned when `use_cache=True`):
  54. Returns the model cache, used to speed up decoding. Different models have a different cache format, check
  55. audio (`list(torch.FloatTensor)` of length `batch_size`):
  56. The generated audio.
  57. """
  58. audio: list[torch.Tensor] | None = None
  59. class CsmGenerationMixin(GenerationMixin):
  60. def _get_stopping_criteria(
  61. self,
  62. *args,
  63. **kwargs,
  64. ) -> StoppingCriteriaList:
  65. criteria = super()._get_stopping_criteria(*args, **kwargs)
  66. kept_criteria = StoppingCriteriaList()
  67. for criterion in criteria:
  68. if not isinstance(criterion, MaxLengthCriteria):
  69. logger.warning(
  70. f"Csm does not support {criterion.__class__.__name__} stopping criteria, it will be ignored."
  71. )
  72. else:
  73. kept_criteria.append(criterion)
  74. return kept_criteria
  75. def _prepare_generation_config(
  76. self, generation_config: GenerationConfig | None, **kwargs: Any
  77. ) -> tuple[GenerationConfig, dict]:
  78. """
  79. This method overrides [~generation.utils.GenerationMixin._prepare_generation_config].
  80. It ensures that the depth decoder generation config is initialized and that passed args as depth_decoder_* are properly handled.
  81. """
  82. # extract depth decoder kwargs and remove them from the main kwargs
  83. depth_decoder_kwargs = {
  84. k[len("depth_decoder_") :]: v for k, v in kwargs.items() if k.startswith("depth_decoder_")
  85. }
  86. # remove the depth decoder keys from the original kwargs
  87. kwargs = {k: v for k, v in kwargs.items() if not k.startswith("depth_decoder_")}
  88. # initialize the generation config
  89. generation_config, model_kwargs = super()._prepare_generation_config(generation_config, **kwargs)
  90. self.depth_decoder.generation_config.update(**depth_decoder_kwargs)
  91. # ensure the depth decoder generation config is valid
  92. depth_decoder_min_new_tokens = getattr(self.depth_decoder.generation_config, "min_new_tokens") or (
  93. self.config.num_codebooks - 1
  94. )
  95. depth_decoder_max_new_tokens = getattr(self.depth_decoder.generation_config, "max_new_tokens") or (
  96. self.config.num_codebooks - 1
  97. )
  98. if {depth_decoder_min_new_tokens, depth_decoder_max_new_tokens} != {self.config.num_codebooks - 1}:
  99. raise ValueError(
  100. f"depth_decoder_generation_config's min_new_tokens ({depth_decoder_min_new_tokens}) and max_new_tokens ({depth_decoder_max_new_tokens}) must be equal to self.config.num_codebooks - 1 ({self.config.num_codebooks - 1})"
  101. )
  102. elif self.depth_decoder.generation_config.return_dict_in_generate:
  103. logger.warning(
  104. "depth_decoder_generation_config.return_dict_in_generate is set to True, but this will be ignored as the depth decoder model does not return a dictionary in generate"
  105. )
  106. self.depth_decoder.generation_config.return_dict_in_generate = False
  107. self.depth_decoder.generation_config.min_new_tokens = depth_decoder_min_new_tokens
  108. self.depth_decoder.generation_config.max_new_tokens = depth_decoder_max_new_tokens
  109. # Monkey patch the get_generation_mode method to support CSM model
  110. original_get_generation_mode = generation_config.get_generation_mode
  111. def patched_get_generation_mode(assistant_model=None):
  112. generation_mode = original_get_generation_mode(assistant_model)
  113. if generation_mode not in [GenerationMode.GREEDY_SEARCH, GenerationMode.SAMPLE]:
  114. raise ValueError(
  115. f"Generation mode {generation_mode} is not supported for CSM model. Please set generation parameters to use greedy or sampling generation."
  116. )
  117. return generation_mode
  118. generation_config.get_generation_mode = patched_get_generation_mode
  119. return generation_config, model_kwargs
  120. def _sample(
  121. self,
  122. input_ids: torch.LongTensor,
  123. logits_processor: LogitsProcessorList,
  124. stopping_criteria: StoppingCriteriaList,
  125. generation_config: GenerationConfig,
  126. synced_gpus: bool = False,
  127. streamer: Optional["BaseStreamer"] = None,
  128. **model_kwargs,
  129. ) -> GenerateNonBeamOutput | torch.LongTensor:
  130. """
  131. This method overrides [~generation.utils.GenerationMixin._sample].
  132. To ease maintenance, modifications are marked with the comment "Csm specific".
  133. Indeed, Csm model requires a custom generation sampling step:
  134. 1. Infer the backbone model to sample the first codebook token
  135. 2. Call generate on the depth decoder with the first codebook token as input_ids to sample the next codebook tokens
  136. 3. Use these generated codebook tokens as input_ids to sample the next first codebook token using the backbone model
  137. 4. Repeat until stopping criteria is met
  138. Csm supports two stopping criteria:
  139. - stop when the generated sequence is at max_length
  140. - stop when all the generated codebook tokens are the codebook_eos_token_id
  141. """
  142. # init values
  143. # *************** Csm specific ***************
  144. pad_token_id = self.config.codebook_pad_token_id
  145. has_eos_stopping_criteria = generation_config._eos_token_tensor is not None
  146. # ============================================
  147. output_attentions = generation_config.output_attentions
  148. output_hidden_states = generation_config.output_hidden_states
  149. output_scores = generation_config.output_scores
  150. output_logits = generation_config.output_logits
  151. return_dict_in_generate = generation_config.return_dict_in_generate
  152. do_sample = generation_config.do_sample
  153. # init attention / hidden states / scores tuples
  154. scores = () if (return_dict_in_generate and output_scores) else None
  155. raw_logits = () if (return_dict_in_generate and output_logits) else None
  156. decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
  157. decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
  158. # keep track of which sequences are already finished
  159. batch_size, cur_len = input_ids.shape[:2]
  160. this_peer_finished = False
  161. unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
  162. # *************** Csm specific ***************
  163. if input_ids.ndim == 2 and model_kwargs.get("inputs_embeds") is None:
  164. # in the case where the passed input_ids correspond to text tokens, i.e. don't have a third dimension for codebook ids,
  165. # we need to remove the input length to the MaxLengthCriteria stopping criteria has such input are not returned
  166. for criterion in stopping_criteria:
  167. if isinstance(criterion, MaxLengthCriteria):
  168. criterion.max_length -= cur_len
  169. # ============================================
  170. model_forward = (
  171. self.get_compiled_call(generation_config.compile_config)
  172. if self._valid_auto_compile_criteria(model_kwargs, generation_config)
  173. else self.__call__
  174. )
  175. # *************** Csm specific ***************
  176. model_kwargs.update({"output_hidden_states": True})
  177. prefill_consumed = False
  178. outputs = self._prefill(
  179. input_ids,
  180. generation_config,
  181. model_kwargs,
  182. is_first_iteration=not generation_config.is_assistant,
  183. )
  184. while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
  185. if prefill_consumed:
  186. next_sequence_length = 1 if model_kwargs["use_cache"] else None
  187. model_inputs = self.prepare_inputs_for_generation(
  188. input_ids, next_sequence_length=next_sequence_length, **model_kwargs
  189. )
  190. # prepare variable output controls (note: some models won't accept all output controls)
  191. model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
  192. outputs = model_forward(**model_inputs, return_dict=True)
  193. prefill_consumed = True
  194. # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
  195. model_kwargs = self._update_model_kwargs_for_generation(
  196. outputs,
  197. model_kwargs,
  198. )
  199. if synced_gpus and this_peer_finished:
  200. continue
  201. # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
  202. # (the clone itself is always small)
  203. next_token_logits = outputs.logits[:, -1, :].clone().float()
  204. next_token_logits = next_token_logits.to(input_ids.device)
  205. # pre-process distribution
  206. next_token_scores = logits_processor(input_ids, next_token_logits)
  207. # Store scores, attentions and hidden_states when required
  208. if return_dict_in_generate:
  209. if output_scores:
  210. scores += (next_token_scores,)
  211. if output_logits:
  212. raw_logits += (next_token_logits,)
  213. if output_attentions:
  214. decoder_attentions += (outputs.attentions,)
  215. if output_hidden_states:
  216. decoder_hidden_states += (outputs.hidden_states,)
  217. # token selection
  218. if do_sample:
  219. probs = nn.functional.softmax(next_token_scores, dim=-1)
  220. # TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution
  221. next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
  222. else:
  223. next_tokens = torch.argmax(next_token_scores, dim=-1)
  224. # *************** Csm specific ***************
  225. # infer the depth decoder
  226. first_codebook_ids = next_tokens[:, None]
  227. # adds place holder in position 0 that will be replaced by the backbone_last_hidden_state
  228. depth_decoder_input_ids = nn.functional.pad(first_codebook_ids, (1, 0), value=0)
  229. backbone_last_hidden_state = outputs.hidden_states[-1][:, -1, :]
  230. depth_decoder_outputs = self.depth_decoder.generate(
  231. input_ids=depth_decoder_input_ids, backbone_last_hidden_state=backbone_last_hidden_state.clone()
  232. )
  233. codebook_ids = (
  234. depth_decoder_outputs
  235. if isinstance(depth_decoder_outputs, torch.Tensor)
  236. else depth_decoder_outputs.sequences
  237. )
  238. # remove the place holder in position 0
  239. codebook_ids = codebook_ids[:, 1:]
  240. next_tokens = codebook_ids
  241. # finished sentences should have their next token be a padding token
  242. if has_eos_stopping_criteria:
  243. next_tokens = next_tokens * unfinished_sequences.unsqueeze(-1) + pad_token_id * (
  244. 1 - unfinished_sequences.unsqueeze(-1)
  245. )
  246. # update generated ids, model inputs, and length for next step
  247. if input_ids.ndim == 2:
  248. input_ids = next_tokens[:, None, :]
  249. else:
  250. input_ids = torch.cat([input_ids, next_tokens[:, None, :]], dim=1)
  251. # ============================================
  252. if streamer is not None:
  253. streamer.put(next_tokens.cpu())
  254. # *************** Csm specific ***************
  255. # for the eos stopping criteria, is it expected that the eos token is the same for each codebook !!!!
  256. unfinished_sequences = unfinished_sequences & ~(
  257. input_ids[:, -1, :-1] == self.config.codebook_eos_token_id
  258. ).all(-1)
  259. # ============================================
  260. unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
  261. this_peer_finished = unfinished_sequences.max() == 0
  262. cur_len += 1
  263. # This is needed to properly delete outputs.logits which may be very large for first iteration
  264. # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
  265. del outputs
  266. # *************** Csm specific ***************
  267. del depth_decoder_outputs
  268. # ============================================
  269. if streamer is not None:
  270. streamer.end()
  271. if return_dict_in_generate:
  272. return GenerateDecoderOnlyOutput(
  273. sequences=input_ids,
  274. scores=scores,
  275. logits=raw_logits,
  276. attentions=decoder_attentions,
  277. hidden_states=decoder_hidden_states,
  278. past_key_values=model_kwargs.get("past_key_values"),
  279. )
  280. else:
  281. return input_ids
  282. def generate(
  283. self,
  284. input_ids: torch.Tensor | None = None,
  285. input_values: torch.Tensor | None = None,
  286. input_values_cutoffs: torch.Tensor | None = None,
  287. generation_config: GenerationConfig | None = None,
  288. logits_processor: LogitsProcessorList | None = None,
  289. stopping_criteria: StoppingCriteriaList | None = None,
  290. synced_gpus: bool | None = None,
  291. streamer: Optional["BaseStreamer"] = None,
  292. output_audio: bool | None = False,
  293. **kwargs,
  294. ) -> GenerateNonBeamOutput | torch.LongTensor:
  295. r"""
  296. This method overrides [`~generation.utils.GenerationMixin.generate`] to match the specifics of the Csm model.
  297. Indeed, Csm model requires a custom generation sampling step:
  298. 1. Infer the backbone model to sample the first codebook token
  299. 2. Call generate on the depth decoder with the first codebook token as `input_ids` to sample the next codebook tokens
  300. 3. Use these generated codebook tokens as `input_ids` to sample the next first codebook token using the backbone model
  301. 4. Repeat until stopping criteria is met
  302. <Tip warning={true}>
  303. Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
  304. model's default generation configuration. You can override any `generation_config` by passing the corresponding
  305. parameters to generate(), e.g. `.generate(inputs, do_sample=True)`.
  306. </Tip>
  307. Parameters:
  308. inputs_ids (`torch.Tensor` of shape (batch_size, seq_length), *optional*):
  309. The sequence used as a prompt for the backbone model.
  310. input_values (`torch.Tensor` of shape (batch_size, channels, max_concatenated_audio_length), *optional*):
  311. The batched audio input values, where each batch entry contains the concatenation of all audio segments for that entry.
  312. These values will be encoded into codebook tokens using the codec model and merged with the text input ids provided in `input_ids`.
  313. input_values_cutoffs (`torch.Tensor` of shape (batch_size, max_num_audio), *optional*):
  314. Specify the end positions of audio segments within each batch entry, relative to the concatenated audio input.
  315. If a batch entry has fewer segments than the maximum, it is padded with -1. For example, in a batch of 2 sequences
  316. where the first contains 2 audio segments of length l1, and the second contains 1 audio segment of length l2,
  317. the input_values_cutoffs would be: [[l1, 2 * l1], [l2, -1]].
  318. generation_config ([`~generation.GenerationConfig`], *optional*):
  319. The generation configuration to be used as base parametrization for the generation call. `**kwargs`
  320. passed to generate matching the attributes of `generation_config` will override them. If
  321. `generation_config` is not provided, the default will be used, which has the following loading
  322. priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
  323. configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
  324. default values, whose documentation should be checked to parameterize generation.
  325. logits_processor (`LogitsProcessorList`, *optional*):
  326. Custom logits processors that complement the default logits processors built from arguments and
  327. generation config. If a logit processor is passed that is already created with the arguments or a
  328. generation config an error is thrown. This feature is intended for advanced users.
  329. stopping_criteria (`StoppingCriteriaList`, *optional*):
  330. Custom stopping criteria that complements the default stopping criteria built from arguments and a
  331. generation config. If a stopping criteria is passed that is already created with the arguments or a
  332. generation config an error is thrown. If your stopping criteria depends on the `scores` input, make
  333. sure you pass `return_dict_in_generate=True, output_scores=True` to `generate`. This feature is
  334. intended for advanced users.
  335. synced_gpus (`bool`, *optional*):
  336. Whether to continue running the while loop until max_length. Unless overridden, this flag will be set
  337. to `True` if using `FullyShardedDataParallel` or DeepSpeed ZeRO Stage 3 with multiple GPUs to avoid
  338. deadlocking if one GPU finishes generating before other GPUs. Otherwise, defaults to `False`.
  339. streamer (`BaseStreamer`, *optional*):
  340. Streamer object that will be used to stream the generated sequences. Generated tokens are passed
  341. through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
  342. output_audio (`bool`, *optional*):
  343. Whether to return the generated audio.
  344. kwargs (`dict[str, Any]`, *optional*):
  345. Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
  346. forwarded to the `forward` function of the model. Depth decoder specific kwargs should be prefixed with *depth_decoder_*.
  347. Return:
  348. [`CsmGenerateOutput`] or `torch.LongTensor` or `list[torch.FloatTensor]`: A [`CsmGenerateOutput`]
  349. (if `return_dict_in_generate=True` or when `config.return_dict_in_generate=True`) or a `torch.LongTensor` when `output_audio=False`
  350. or a `list[torch.FloatTensor]` otherwise.
  351. Example:
  352. ```python
  353. >>> from transformers import CsmProcessor, CsmForConditionalGeneration
  354. >>> from datasets import load_dataset, Audio
  355. >>> model_id = "sesame/csm-1b"
  356. >>> torch_device = "cuda" if torch.cuda.is_available() else "cpu"
  357. >>> processor = AutoProcessor.from_pretrained(model_id)
  358. >>> ds = load_dataset("hf-internal-testing/dailytalk-dummy", split="train")
  359. >>> # ensure the audio is 24kHz
  360. >>> ds = ds.cast_column("audio", Audio(sampling_rate=24000))
  361. >>> conversation = []
  362. >>> # prepare a conversation with text and corresponding audio
  363. >>> for text, audio, speaker_id in zip(ds[:4]["text"], ds[:4]["audio"], ds[:4]["speaker_id"]):
  364. ... conversation.append(
  365. ... {
  366. ... "role": f"{speaker_id}",
  367. ... "content": [{"type": "text", "text": text}, {"type": "audio", "path": audio["array"]}],
  368. ... }
  369. ... )
  370. >>> # text prompt
  371. >>> conversation.append({"role": f"{ds[4]['speaker_id']}", "content": [{"type": "text", "text": ds[4]["text"]}]})
  372. >>> inputs = processor.apply_chat_template(
  373. ... conversation,
  374. ... tokenize=True,
  375. ... return_dict=True,
  376. ... ).to(torch_device)
  377. >>> model = CsmForConditionalGeneration.from_pretrained(model_id, device_map=torch_device)
  378. >>> audio = model.generate(**inputs, output_audio=True)
  379. >>> processor.save_audio(audio, "output.wav")
  380. ```
  381. """
  382. generate_output = super().generate(
  383. input_ids=input_ids,
  384. input_values=input_values,
  385. input_values_cutoffs=input_values_cutoffs,
  386. generation_config=generation_config,
  387. logits_processor=logits_processor,
  388. stopping_criteria=stopping_criteria,
  389. synced_gpus=synced_gpus,
  390. streamer=streamer,
  391. **kwargs,
  392. )
  393. generate_returned_dict = not isinstance(generate_output, torch.Tensor)
  394. audio = None
  395. if output_audio:
  396. generated_audio_codes = generate_output.sequences if generate_returned_dict else generate_output
  397. # infer the codec model
  398. audio = []
  399. with torch.no_grad():
  400. # =======================================
  401. # TODO: @eustlb, this should be batched !!!
  402. # but requires making sure batched inference of the codec model works as intended
  403. for audio_codes_batch in generated_audio_codes:
  404. eos_idxs = (audio_codes_batch == self.config.codebook_eos_token_id).all(dim=-1).nonzero()
  405. if eos_idxs.numel() != 0:
  406. cutoff_idx = eos_idxs.min()
  407. else:
  408. cutoff_idx = audio_codes_batch.shape[0]
  409. audio_codes_batch = audio_codes_batch[:cutoff_idx]
  410. codec_decode_output = self.codec_model.decode(audio_codes_batch.transpose(0, 1).unsqueeze(0))
  411. audio.append(codec_decode_output.audio_values[0, 0])
  412. # =======================================
  413. if generate_returned_dict:
  414. return CsmGenerateOutput(audio=audio, **generate_output)
  415. elif output_audio:
  416. return audio
  417. else:
  418. return generate_output