generation_dia.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462
  1. # Copyright 2025 The Nari Labs and HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from collections.abc import Callable
  15. from typing import Any, Optional
  16. import torch
  17. import torch.distributed as dist
  18. from ...generation.logits_process import (
  19. DiaClassifierFreeGuidanceLogitsProcessor,
  20. DiaEOSChannelFilterLogitsProcessor,
  21. DiaEOSDelayPatternLogitsProcessor,
  22. LogitsProcessorList,
  23. TemperatureLogitsWarper,
  24. )
  25. from ...generation.stopping_criteria import StoppingCriteriaList
  26. from ...generation.streamers import BaseStreamer
  27. from ...generation.utils import GenerateOutput, GenerationConfig, GenerationMixin, GenerationMode
  28. from ...integrations.deepspeed import is_deepspeed_zero3_enabled
  29. from ...integrations.fsdp import is_fsdp_managed_module
  30. from ...modeling_utils import PreTrainedModel
  31. from ...utils import logging
# Module-level logger from the project's `...utils.logging`; used for the `warning_once` messages emitted during Dia generation.
logger = logging.get_logger(__name__)
  33. class DiaGenerationMixin(GenerationMixin):
  34. # Indicates CFG which needs preparation to be properly handled by repeats
  35. _uses_cfg = None
  36. def _get_logits_processor(
  37. self,
  38. generation_config: GenerationConfig,
  39. input_ids_seq_length: int | None = None,
  40. encoder_input_ids: torch.LongTensor | None = None,
  41. prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], list[int]] | None = None,
  42. logits_processor: LogitsProcessorList | None = None,
  43. device: str | None = None,
  44. model_kwargs: dict[str, Any] | None = None,
  45. negative_prompt_ids: torch.Tensor | None = None,
  46. negative_prompt_attention_mask: torch.Tensor | None = None,
  47. ) -> LogitsProcessorList:
  48. # Need either custom order or custom processor instead
  49. # (Temporarily disabling those for the super function)
  50. original_guidance_scale = generation_config.guidance_scale
  51. original_temperature = generation_config.temperature
  52. generation_config.guidance_scale = None
  53. generation_config.temperature = None
  54. # Get base processors and those we can integrate easily
  55. custom_processors = LogitsProcessorList()
  56. if original_temperature is not None and original_temperature != 1.0:
  57. custom_processors.append(TemperatureLogitsWarper(original_temperature))
  58. custom_processors.append(
  59. DiaEOSChannelFilterLogitsProcessor(
  60. num_channels=len(self.config.delay_pattern),
  61. eos_token_id=self.config.decoder_config.eos_token_id,
  62. )
  63. )
  64. merged_processors = super()._get_logits_processor(
  65. generation_config=generation_config,
  66. input_ids_seq_length=input_ids_seq_length,
  67. encoder_input_ids=encoder_input_ids,
  68. prefix_allowed_tokens_fn=None,
  69. logits_processor=custom_processors,
  70. device=device,
  71. model_kwargs=model_kwargs,
  72. negative_prompt_ids=negative_prompt_ids,
  73. negative_prompt_attention_mask=negative_prompt_attention_mask,
  74. )
  75. # Custom processors we need at specific positions
  76. if original_guidance_scale is not None and original_guidance_scale != 1:
  77. cfg_processor = DiaClassifierFreeGuidanceLogitsProcessor(
  78. guidance_scale=original_guidance_scale,
  79. guidance_top_k=generation_config.top_k,
  80. )
  81. merged_processors.insert(0, cfg_processor)
  82. merged_processors.append(
  83. DiaEOSDelayPatternLogitsProcessor(
  84. delay_pattern=self.config.delay_pattern,
  85. eos_token_id=self.config.decoder_config.eos_token_id,
  86. max_generation_len=generation_config.max_length,
  87. device=device,
  88. )
  89. )
  90. # Enable temporarily disabled values back
  91. generation_config.guidance_scale = original_guidance_scale
  92. generation_config.temperature = original_temperature
  93. return merged_processors
  94. def _prepare_generation_config(
  95. self, generation_config: GenerationConfig | None, **kwargs: Any
  96. ) -> tuple[GenerationConfig, dict]:
  97. generation_config, model_kwargs = super()._prepare_generation_config(generation_config, **kwargs)
  98. if generation_config.temperature is not None and generation_config.temperature < 1.0:
  99. logger.warning_once(
  100. f"temperature < 1.0 is not supported for Dia; clamping to 1.0 (got {generation_config.temperature})"
  101. )
  102. generation_config.temperature = 1.0
  103. # We allow generation up to max length + max delay pattern
  104. # (will revert back to max length after generation)
  105. generation_config.max_length += max(self.config.delay_pattern)
  106. # Internal flag to indicate CFG that needs to prepare unconditioned input
  107. self._uses_cfg = generation_config.guidance_scale is not None and generation_config.guidance_scale != 1
  108. return generation_config, model_kwargs
  109. def _prepare_model_inputs(
  110. self,
  111. inputs: torch.Tensor | None = None,
  112. bos_token_id: torch.Tensor | None = None,
  113. model_kwargs: dict[str, torch.Tensor] | None = None,
  114. ) -> tuple[torch.Tensor, str | None, dict[str, torch.Tensor]]:
  115. inputs, input_name, model_kwargs = super()._prepare_model_inputs(
  116. inputs=inputs,
  117. bos_token_id=bos_token_id,
  118. model_kwargs=model_kwargs,
  119. )
  120. # If CFG is requested we fill in the unconditioned parts
  121. if self._uses_cfg:
  122. unconditioned_inputs = torch.zeros_like(inputs)
  123. inputs = torch.cat([inputs, unconditioned_inputs], dim=0)
  124. if model_kwargs.get("attention_mask", None) is not None:
  125. model_kwargs["attention_mask"] = model_kwargs["attention_mask"].repeat(2, 1)
  126. return inputs, input_name, model_kwargs
  127. def _prepare_decoder_input_ids_for_generation(
  128. self,
  129. batch_size: int,
  130. model_input_name: str,
  131. model_kwargs: dict[str, torch.Tensor],
  132. decoder_start_token_id: torch.Tensor,
  133. device: torch.device | None = None,
  134. ) -> tuple[torch.LongTensor, dict[str, torch.Tensor]]:
  135. """Prepares `decoder_input_ids` for generation with encoder-decoder models"""
  136. # 1. Check whether the user has defined `decoder_input_ids` and `decoder_attention_mask`; if not error out
  137. decoder_input_ids = decoder_attention_mask = None
  138. if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
  139. decoder_input_ids = model_kwargs.pop("decoder_input_ids")
  140. if model_kwargs is not None and "decoder_attention_mask" in model_kwargs:
  141. decoder_attention_mask = model_kwargs.pop("decoder_attention_mask")
  142. # We allow generating without preparation (no proper delay) but discourage it
  143. if decoder_input_ids is None or decoder_attention_mask is None:
  144. logger.warning_once(
  145. "In order to generate with Dia, we need the processed audio input: Got `decoder_input_ids`:"
  146. f" {decoder_input_ids is not None} and got `decoder_attention_mask`={decoder_attention_mask is not None}."
  147. f" This can be achieved via the [`DiaProcessor`] but now defaulting to non-delayed generation."
  148. )
  149. num_channels = self.config.decoder_config.num_channels
  150. real_batch_size = batch_size // 2 if self._uses_cfg else batch_size
  151. if decoder_input_ids is None:
  152. decoder_input_ids = torch.full(
  153. (real_batch_size, 1, num_channels), decoder_start_token_id, dtype=torch.long, device=device
  154. )
  155. decoder_attention_mask = torch.ones(
  156. size=(real_batch_size, decoder_input_ids.shape[1]), dtype=torch.long, device=device
  157. )
  158. # 2. Determine the valid input and what works as mask within the input
  159. delay_mask = decoder_input_ids.long()
  160. valid_input_size = (
  161. decoder_input_ids.shape[1]
  162. - (decoder_input_ids[:, :, 0] == self.config.decoder_config.pad_token_id).sum(dim=-1).max()
  163. )
  164. decoder_input_ids = delay_mask[:, :valid_input_size].transpose(1, 2).long()
  165. decoder_attention_mask = decoder_attention_mask[:, :valid_input_size].long()
  166. # 3. Overwrite into model kwargs
  167. model_kwargs["decoder_attention_mask"] = decoder_attention_mask
  168. model_kwargs["decoder_delay_mask"] = delay_mask
  169. return decoder_input_ids, model_kwargs
  170. def prepare_inputs_for_generation(
  171. self,
  172. input_ids,
  173. encoder_outputs=None, # Using this to easily get the batch size
  174. decoder_delay_mask=None,
  175. is_first_iteration: bool | None = False,
  176. **kwargs,
  177. ):
  178. # Reshape decoder input_ids to 3D to be compile friendly and to fit the expected model input shape
  179. batch_size = encoder_outputs[0].shape[0] // 2 if self._uses_cfg else encoder_outputs[0].shape[0]
  180. input_ids = input_ids.reshape(batch_size, self.config.decoder_config.num_channels, -1).transpose(1, 2)
  181. # Base method handles most things except CFG and the delay pattern mask
  182. model_inputs = super().prepare_inputs_for_generation(input_ids, encoder_outputs=encoder_outputs, **kwargs)
  183. # Post processing for CFG and overwriting via delay pattern mask
  184. # 1. Delay pattern mask -- force tokens if not allowed to predict (!= pad_token in mask)
  185. model_inputs["decoder_input_ids"] = self.apply_delay_mask(
  186. input_ids, self.config.decoder_config.pad_token_id, decoder_delay_mask
  187. )
  188. # Depending on cache usage we need to pass all or just one
  189. if model_inputs.get("use_cache", False) and not is_first_iteration:
  190. model_inputs["decoder_input_ids"] = model_inputs["decoder_input_ids"][:, -1, :][:, None, :]
  191. # Be compile friendly
  192. model_inputs["decoder_input_ids"] = model_inputs["decoder_input_ids"].contiguous()
  193. # 2. Apply CFG duplication if needed
  194. if self._uses_cfg:
  195. for key in ["decoder_input_ids", "decoder_attention_mask", "decoder_position_ids"]:
  196. if model_inputs.get(key, None) is not None:
  197. # double first dimension and keep everything else the same
  198. repeat_pattern = tuple([2] + [1] * (model_inputs[key].ndim - 1))
  199. model_inputs[key] = model_inputs[key].repeat(*repeat_pattern)
  200. return model_inputs
  201. @staticmethod
  202. def apply_delay_mask(input_ids: torch.Tensor, pad_id: int, delay_mask: torch.Tensor | None) -> torch.Tensor:
  203. if delay_mask is None:
  204. return input_ids
  205. mask_len = min(input_ids.shape[1], delay_mask.shape[1])
  206. valid_mask = delay_mask[:, :mask_len, :]
  207. valid_input = input_ids[:, :mask_len, :]
  208. # Overwrite the respective parts of the input
  209. input_ids[:, :mask_len, :] = torch.where(valid_mask == pad_id, valid_input, valid_mask)
  210. return input_ids
  211. def _main_generate_loop(
  212. self,
  213. inputs: torch.Tensor | None = None,
  214. generation_config: GenerationConfig | None = None,
  215. logits_processor: LogitsProcessorList | None = None,
  216. stopping_criteria: StoppingCriteriaList | None = None,
  217. prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], list[int]] | None = None,
  218. synced_gpus: bool | None = None,
  219. assistant_model: Optional["PreTrainedModel"] = None,
  220. streamer: Optional["BaseStreamer"] = None,
  221. negative_prompt_ids: torch.Tensor | None = None,
  222. negative_prompt_attention_mask: torch.Tensor | None = None,
  223. custom_generate: str | None = None,
  224. **kwargs,
  225. ):
  226. # ********** mostly taken from main generate function up to calling the different methods (see NOTE) **********
  227. # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
  228. generation_mode_kwargs = self._extract_generation_mode_kwargs(
  229. custom_generate,
  230. kwargs,
  231. synced_gpus,
  232. assistant_model,
  233. streamer,
  234. )
  235. generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs)
  236. generation_mode = generation_config.get_generation_mode(assistant_model)
  237. if generation_mode not in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
  238. raise ValueError(
  239. "Got incompatible mode for generation, should be one of greedy or sampling. "
  240. "Ensure that beam search is de-activated by setting `num_beams=1`."
  241. )
  242. self._validate_model_kwargs(model_kwargs.copy())
  243. self._validate_generation_mode(generation_mode, generation_config, generation_mode_kwargs)
  244. # 2. Set generation parameters if not already defined
  245. if synced_gpus is None:
  246. synced_gpus = (is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)) and dist.get_world_size() > 1
  247. logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
  248. stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
  249. # 3. Define model inputs
  250. kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
  251. inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
  252. inputs, generation_config.bos_token_id, model_kwargs
  253. )
  254. batch_size = inputs_tensor.shape[0]
  255. device = inputs_tensor.device
  256. self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)
  257. # 4. Define other model kwargs
  258. if "encoder_outputs" not in model_kwargs:
  259. # if model is encoder decoder encoder_outputs are created and added to `model_kwargs`
  260. model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
  261. inputs_tensor, model_kwargs, model_input_name, generation_config
  262. )
  263. # 5. Prepare `input_ids` which will be used for auto-regressive generation
  264. input_ids, model_kwargs = self._prepare_decoder_input_ids_for_generation(
  265. batch_size=batch_size,
  266. model_input_name=model_input_name,
  267. model_kwargs=model_kwargs,
  268. decoder_start_token_id=generation_config._decoder_start_token_tensor,
  269. device=inputs_tensor.device,
  270. )
  271. if generation_config.token_healing:
  272. input_ids = self.heal_tokens(input_ids, generation_mode_kwargs.get("tokenizer"))
  273. if streamer is not None:
  274. streamer.put(input_ids.cpu())
  275. # 6. Prepare `max_length` depending on other stopping criteria.
  276. # NOTE: incorrect `input_ids.shape[1]` previously
  277. input_ids_length = input_ids.shape[-1]
  278. has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
  279. has_default_min_length = kwargs.get("min_length") is None and generation_config.min_length is not None
  280. generation_config = self._prepare_generated_length(
  281. generation_config=generation_config,
  282. has_default_max_length=has_default_max_length,
  283. has_default_min_length=has_default_min_length,
  284. model_input_name=model_input_name,
  285. inputs_tensor=inputs_tensor,
  286. input_ids_length=input_ids_length,
  287. )
  288. # If the model supports `logits_to_keep` in forward(), set it to 1 to avoid computing the whole
  289. # logit matrix. This can save a lot of memory during the first forward pass. Note that assisted decoding
  290. # dynamically overrides this value as it can need more than the last token logits
  291. if self._supports_logits_to_keep() and "logits_to_keep" not in model_kwargs:
  292. model_kwargs["logits_to_keep"] = 1
  293. self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
  294. # 7. Prepare the cache.
  295. # - `model_kwargs` may be updated in place with a cache as defined by the parameters in `generation_config`.
  296. # - different models have a different cache name expected by the model (default = "past_key_values")
  297. # - `max_length`, prepared above, is used to determine the maximum cache length
  298. max_cache_length = generation_config.max_length - 1
  299. if (
  300. inputs_tensor.shape[1] != input_ids_length
  301. and model_input_name == "inputs_embeds"
  302. and not self.config.is_encoder_decoder
  303. ):
  304. max_cache_length += inputs_tensor.shape[1]
  305. self._prepare_cache_for_generation(
  306. generation_config, model_kwargs, generation_mode, batch_size, max_cache_length
  307. )
  308. # 8. prepare logits processors and stopping criteria
  309. prepared_logits_processor = self._get_logits_processor(
  310. generation_config=generation_config,
  311. input_ids_seq_length=input_ids_length,
  312. encoder_input_ids=inputs_tensor,
  313. prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
  314. logits_processor=logits_processor,
  315. device=inputs_tensor.device,
  316. model_kwargs=model_kwargs,
  317. negative_prompt_ids=negative_prompt_ids,
  318. negative_prompt_attention_mask=negative_prompt_attention_mask,
  319. )
  320. prepared_stopping_criteria = self._get_stopping_criteria(
  321. generation_config=generation_config,
  322. stopping_criteria=stopping_criteria,
  323. tokenizer=generation_mode_kwargs.get("tokenizer"),
  324. )
  325. # Set model_kwargs `use_cache` so we can use it later in forward runs
  326. model_kwargs["use_cache"] = generation_config.use_cache
  327. # ******************* taken from main generate function up to calling the different methods *******************
  328. # Prepare inner 2D logic in generation loop
  329. input_ids = input_ids.reshape(-1, input_ids.shape[-1])
  330. # 10. expand input_ids with `num_return_sequences` additional sequences per batch
  331. if generation_config.num_return_sequences > 1:
  332. raise ValueError("`num_return_sequences>1` is incompatible with Dia.")
  333. # 11. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
  334. return self._sample(
  335. input_ids,
  336. logits_processor=prepared_logits_processor,
  337. stopping_criteria=prepared_stopping_criteria,
  338. generation_config=generation_config,
  339. **generation_mode_kwargs,
  340. **model_kwargs,
  341. )
  342. @torch.no_grad()
  343. def generate(
  344. self,
  345. inputs: torch.Tensor | None = None,
  346. generation_config: GenerationConfig | None = None,
  347. logits_processor: LogitsProcessorList | None = None,
  348. stopping_criteria: StoppingCriteriaList | None = None,
  349. prefix_allowed_tokens_fn: Callable[[int, torch.Tensor], list[int]] | None = None,
  350. synced_gpus: bool | None = None,
  351. assistant_model: Optional["PreTrainedModel"] = None,
  352. streamer: Optional["BaseStreamer"] = None,
  353. negative_prompt_ids: torch.Tensor | None = None,
  354. negative_prompt_attention_mask: torch.Tensor | None = None,
  355. custom_generate: str | None = None,
  356. **kwargs,
  357. ) -> GenerateOutput | torch.LongTensor:
  358. # We expect the initial input ids to be the complete mask (delayed input)
  359. delay_mask = kwargs.get("decoder_input_ids")
  360. if delay_mask is not None:
  361. delay_mask = delay_mask.clone()
  362. output = self._main_generate_loop(
  363. inputs=inputs,
  364. generation_config=generation_config,
  365. logits_processor=logits_processor,
  366. stopping_criteria=stopping_criteria,
  367. prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
  368. synced_gpus=synced_gpus,
  369. assistant_model=assistant_model,
  370. streamer=streamer,
  371. negative_prompt_ids=negative_prompt_ids,
  372. negative_prompt_attention_mask=negative_prompt_attention_mask,
  373. custom_generate=custom_generate,
  374. **kwargs,
  375. )
  376. return_dict_in_generate = not isinstance(output, torch.Tensor)
  377. if return_dict_in_generate:
  378. output_sequences = output.sequences
  379. else:
  380. output_sequences = output
  381. # Reshape from 2D (bsz * channels, seq_len) to 3D (bsz, seq_len, channels)
  382. num_channels = self.config.decoder_config.num_channels
  383. bsz = output_sequences.shape[0] // num_channels
  384. output_sequences = output_sequences.reshape(bsz, num_channels, -1).transpose(1, 2)
  385. # Apply delay mask
  386. output_sequences = self.apply_delay_mask(output_sequences, self.config.decoder_config.pad_token_id, delay_mask)
  387. if return_dict_in_generate:
  388. output.sequences = output_sequences
  389. else:
  390. output = output_sequences
  391. return output