modular_dia.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660
  1. # Copyright 2025 The Nari Labs and HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch Dia model."""
  15. from collections.abc import Callable
  16. import torch
  17. from torch import nn
  18. from ... import initialization as init
  19. from ...cache_utils import DynamicCache, EncoderDecoderCache
  20. from ...masking_utils import create_bidirectional_mask, create_causal_mask
  21. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  22. from ...modeling_layers import GradientCheckpointingLayer
  23. from ...modeling_outputs import (
  24. BaseModelOutput,
  25. BaseModelOutputWithPastAndCrossAttentions,
  26. Seq2SeqLMOutput,
  27. Seq2SeqModelOutput,
  28. )
  29. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  30. from ...processing_utils import Unpack
  31. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging
  32. from ...utils.generic import merge_with_config_defaults
  33. from ...utils.output_capturing import capture_outputs
  34. from ..llama.modeling_llama import (
  35. LlamaAttention,
  36. LlamaRMSNorm,
  37. LlamaRotaryEmbedding,
  38. eager_attention_forward,
  39. )
  40. from ..phi3.modeling_phi3 import Phi3MLP
  41. from .configuration_dia import DiaConfig, DiaDecoderConfig, DiaEncoderConfig
  42. from .generation_dia import DiaGenerationMixin
# Module-level logger from the library's logging utilities (DiaModel.forward uses it for warn-once messages).
logger = logging.get_logger(__name__)
  44. @auto_docstring
  45. class DiaPreTrainedModel(PreTrainedModel):
  46. config: DiaConfig
  47. base_model_prefix = "model"
  48. supports_gradient_checkpointing = True
  49. _supports_flash_attn = True
  50. _supports_sdpa = True
  51. _supports_flex_attn = True
  52. _can_compile_fullgraph = True
  53. main_input_name = "input_ids"
  54. _no_split_modules = ["DiaEncoderLayer", "DiaDecoderLayer"]
  55. def _init_weights(self, module):
  56. super()._init_weights(module)
  57. if isinstance(module, DiaMultiChannelEmbedding):
  58. offsets = torch.arange(self.config.num_channels, dtype=torch.long) * self.config.vocab_size
  59. init.copy_(module.offsets, offsets)
  60. class DiaMultiChannelEmbedding(nn.Module):
  61. """In order to efficiently compute the audio embedding from the 9 different channels,
  62. we vectorize the embedding process by using a single embedding layer and an offset.
  63. Example:
  64. - num_embeds = 4
  65. - vocab_size = 8
  66. - num_channels = 3
  67. We would have offsets = [0, 8, 16]
  68. If audio_codes = [0, 1, 2, 3], [1, 3, 4, 7], [5, 6, 7, 8],
  69. then tokens = audio_codes + offsets
  70. = [0, 1, 2, 3, 9, 11, 12, 15, 21, 22, 23, 24]
  71. This allows us to use a single embedding layer for all channels.
  72. """
  73. def __init__(self, config: DiaDecoderConfig):
  74. super().__init__()
  75. self.embed = nn.Embedding(config.vocab_size * config.num_channels, config.hidden_size)
  76. self.hidden_size = config.hidden_size
  77. self.num_channels = config.num_channels
  78. offsets = torch.arange(config.num_channels, dtype=torch.long) * config.vocab_size # (C,)
  79. self.register_buffer("offsets", offsets, persistent=False)
  80. def forward(self, audio_codes: torch.Tensor) -> torch.Tensor:
  81. tokens = (audio_codes + self.offsets.to(audio_codes.device)).squeeze(1)
  82. embeds = self.embed(tokens).view(tokens.shape[0], audio_codes.shape[1], -1, self.hidden_size)
  83. return embeds.sum(dim=2)
  84. class DiaMLP(Phi3MLP):
  85. pass
  86. class DiaRMSNorm(LlamaRMSNorm):
  87. pass
  88. class DiaRotaryEmbedding(LlamaRotaryEmbedding):
  89. pass
  90. class DiaSelfAttention(LlamaAttention):
  91. """Multi-headed attention from 'Attention Is All You Need' paper"""
  92. def __init__(self, config: DiaEncoderConfig | DiaDecoderConfig, layer_idx: int, is_causal: bool = False):
  93. nn.Module.__init__(self)
  94. self.config = config
  95. self.layer_idx = layer_idx
  96. self.hidden_size = config.hidden_size
  97. self.num_heads = self.config.num_attention_heads
  98. self.num_key_value_heads = self.config.num_key_value_heads or self.num_heads
  99. self.num_key_value_groups = self.num_heads // self.num_key_value_heads
  100. self.head_dim = getattr(config, "head_dim", config.hidden_size // self.num_heads)
  101. self.scaling = 1
  102. self.attention_dropout = 0.0
  103. self.is_causal = is_causal
  104. self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
  105. self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
  106. self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
  107. self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
  108. class DiaCrossAttention(nn.Module):
  109. """Multi-headed attention from 'Attention Is All You Need' paper"""
  110. def __init__(self, config: DiaDecoderConfig, layer_idx: int):
  111. super().__init__()
  112. self.config = config
  113. self.layer_idx = layer_idx
  114. self.hidden_size = config.hidden_size
  115. self.cross_hidden_size = config.cross_hidden_size
  116. self.num_heads = self.config.cross_num_attention_heads
  117. self.num_key_value_heads = self.config.cross_num_key_value_heads
  118. self.num_key_value_groups = self.num_heads // self.num_key_value_heads
  119. self.head_dim = config.cross_head_dim
  120. self.scaling = 1
  121. self.attention_dropout = 0.0
  122. self.is_causal = False
  123. self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
  124. self.k_proj = nn.Linear(self.cross_hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
  125. self.v_proj = nn.Linear(self.cross_hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
  126. self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
  127. def forward(
  128. self,
  129. hidden_states: torch.Tensor,
  130. cross_attention_states: torch.Tensor,
  131. attention_mask: torch.Tensor | None = None,
  132. past_key_values: EncoderDecoderCache | None = None,
  133. **kwargs: Unpack[FlashAttentionKwargs],
  134. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  135. input_shape = hidden_states.shape[:-1]
  136. hidden_shape = (*input_shape, -1, self.head_dim)
  137. cross_shape = (*cross_attention_states.shape[:-1], -1, self.head_dim)
  138. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  139. is_updated = past_key_values.is_updated.get(self.layer_idx) if past_key_values is not None else False
  140. if past_key_values is not None and is_updated:
  141. # reuse k,v, cross_attentions
  142. key_states = past_key_values.cross_attention_cache.layers[self.layer_idx].keys
  143. value_states = past_key_values.cross_attention_cache.layers[self.layer_idx].values
  144. else:
  145. key_states = self.k_proj(cross_attention_states).view(cross_shape).transpose(1, 2)
  146. value_states = self.v_proj(cross_attention_states).view(cross_shape).transpose(1, 2)
  147. if past_key_values is not None:
  148. # save all states to the cache
  149. key_states, value_states = past_key_values.cross_attention_cache.update(
  150. key_states,
  151. value_states,
  152. self.layer_idx,
  153. )
  154. # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
  155. past_key_values.is_updated[self.layer_idx] = True
  156. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  157. self.config._attn_implementation, eager_attention_forward
  158. )
  159. attn_output, attn_weights = attention_interface(
  160. self,
  161. query_states,
  162. key_states,
  163. value_states,
  164. attention_mask,
  165. scaling=self.scaling,
  166. **kwargs,
  167. )
  168. attn_output = attn_output.reshape((*input_shape, -1)).contiguous()
  169. attn_output = self.o_proj(attn_output)
  170. return attn_output, attn_weights
  171. class DiaEncoderLayer(GradientCheckpointingLayer):
  172. def __init__(self, config: DiaEncoderConfig, layer_idx: int):
  173. super().__init__()
  174. self.pre_sa_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
  175. self.self_attention = DiaSelfAttention(config, layer_idx, is_causal=False)
  176. self.post_sa_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
  177. self.mlp = DiaMLP(config)
  178. def forward(
  179. self,
  180. hidden_states: torch.Tensor,
  181. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  182. attention_mask: torch.Tensor | None = None,
  183. **kwargs: Unpack[FlashAttentionKwargs],
  184. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  185. residual = hidden_states
  186. normed_states = self.pre_sa_norm(hidden_states)
  187. self_attn_output, _ = self.self_attention(
  188. normed_states,
  189. position_embeddings=position_embeddings,
  190. attention_mask=attention_mask,
  191. **kwargs,
  192. )
  193. hidden_states = residual + self_attn_output
  194. residual = hidden_states
  195. normed_states = self.post_sa_norm(hidden_states)
  196. mlp_out = self.mlp(normed_states)
  197. hidden_states = residual + mlp_out
  198. return hidden_states
  199. class DiaEncoder(DiaPreTrainedModel):
  200. _can_record_outputs = {
  201. "hidden_states": DiaEncoderLayer,
  202. "attentions": DiaSelfAttention,
  203. }
  204. def __init__(self, config: DiaEncoderConfig):
  205. super().__init__(config)
  206. self.config = config
  207. self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
  208. self.layers = nn.ModuleList(
  209. [DiaEncoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  210. )
  211. self.norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
  212. self.rotary_emb = DiaRotaryEmbedding(config=config)
  213. self.post_init()
  214. @merge_with_config_defaults
  215. @capture_outputs
  216. @auto_docstring
  217. def forward(
  218. self,
  219. input_ids: torch.Tensor,
  220. attention_mask: torch.Tensor | None = None,
  221. **kwargs: Unpack[TransformersKwargs],
  222. ) -> BaseModelOutput:
  223. hidden_states = self.embedding(input_ids)
  224. # RoPE
  225. # Note: We expect right padding and hence always generate
  226. # the position ids on the fly to reduce preparation overhead
  227. position_ids = torch.arange(input_ids.shape[-1], device=input_ids.device)[None, :]
  228. attention_mask = create_bidirectional_mask(
  229. config=self.config,
  230. inputs_embeds=hidden_states,
  231. attention_mask=attention_mask,
  232. )
  233. position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
  234. for encoder_layer in self.layers:
  235. hidden_states = encoder_layer(
  236. hidden_states,
  237. attention_mask=attention_mask,
  238. position_ids=position_ids,
  239. position_embeddings=position_embeddings,
  240. **kwargs,
  241. )
  242. hidden_states = self.norm(hidden_states)
  243. return BaseModelOutput(last_hidden_state=hidden_states)
  244. class DiaDecoderLayer(GradientCheckpointingLayer):
  245. def __init__(self, config: DiaDecoderConfig, layer_idx: int):
  246. super().__init__()
  247. self.embed_dim = config.hidden_size
  248. self.self_attention = DiaSelfAttention(config, layer_idx, is_causal=True)
  249. self.cross_attention = DiaCrossAttention(config, layer_idx)
  250. self.pre_sa_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
  251. self.pre_ca_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
  252. self.pre_mlp_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
  253. self.mlp = DiaMLP(config)
  254. def forward(
  255. self,
  256. hidden_states: torch.Tensor,
  257. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  258. attention_mask: torch.Tensor | None = None,
  259. encoder_hidden_states: torch.Tensor | None = None,
  260. encoder_attention_mask: torch.Tensor | None = None,
  261. past_key_values: EncoderDecoderCache | None = None,
  262. **kwargs,
  263. ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
  264. self_attn_cache = past_key_values
  265. if isinstance(self_attn_cache, EncoderDecoderCache):
  266. self_attn_cache = self_attn_cache.self_attention_cache
  267. residual = hidden_states
  268. normed_states = self.pre_sa_norm(hidden_states)
  269. self_attn_output, _ = self.self_attention(
  270. normed_states,
  271. position_embeddings,
  272. attention_mask,
  273. # Needs to be an arg in order to function properly
  274. # on inplace operations to be carried (e.g. compile)
  275. self_attn_cache,
  276. **kwargs,
  277. )
  278. hidden_states = residual + self_attn_output
  279. residual = hidden_states
  280. normed_states = self.pre_ca_norm(hidden_states)
  281. cross_states, _ = self.cross_attention(
  282. normed_states,
  283. encoder_hidden_states,
  284. attention_mask=encoder_attention_mask,
  285. past_key_values=past_key_values,
  286. **kwargs,
  287. )
  288. hidden_states = residual + cross_states
  289. residual = hidden_states
  290. normed_states = self.pre_mlp_norm(hidden_states)
  291. mlp_out = self.mlp(normed_states)
  292. hidden_states = residual + mlp_out
  293. return hidden_states
  294. class DiaDecoder(DiaPreTrainedModel):
  295. """Transformer Decoder Stack using DenseGeneral."""
  296. _can_record_outputs = {
  297. "hidden_states": DiaDecoderLayer,
  298. "attentions": [DiaSelfAttention, DiaCrossAttention],
  299. }
  300. def __init__(self, config: DiaDecoderConfig):
  301. super().__init__(config)
  302. self.num_channels = config.num_channels
  303. self.vocab_size = config.vocab_size
  304. self.embeddings = DiaMultiChannelEmbedding(config)
  305. self.layers = nn.ModuleList(
  306. [DiaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  307. )
  308. self.norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
  309. self.rotary_emb = DiaRotaryEmbedding(config=config)
  310. self.post_init()
  311. @merge_with_config_defaults
  312. @capture_outputs
  313. @auto_docstring
  314. def forward(
  315. self,
  316. input_ids: torch.Tensor,
  317. position_ids: torch.LongTensor | None = None,
  318. attention_mask: torch.Tensor | None = None,
  319. encoder_hidden_states: torch.FloatTensor | None = None,
  320. encoder_attention_mask: torch.LongTensor | None = None,
  321. past_key_values: EncoderDecoderCache | None = None,
  322. **kwargs: Unpack[TransformersKwargs],
  323. ) -> BaseModelOutputWithPastAndCrossAttentions | tuple:
  324. r"""
  325. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`):
  326. The original `decoder_input_ids` in 3D shape to facilitate more efficient computations.
  327. [What are input IDs?](../glossary#input-ids)
  328. """
  329. batch_size, seq_length = input_ids.size()[:-1]
  330. past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
  331. if position_ids is None:
  332. position_ids = torch.arange(seq_length, device=input_ids.device) + past_key_values_length
  333. position_ids = position_ids.unsqueeze(0)
  334. # RoPE
  335. hidden_states = self.embeddings(input_ids)
  336. if attention_mask is None and not is_torchdynamo_compiling():
  337. # required mask seq length can be calculated via length of past cache
  338. mask_seq_length = past_key_values_length + seq_length
  339. attention_mask = torch.ones(batch_size, mask_seq_length, device=input_ids.device)
  340. attention_mask = create_causal_mask(
  341. config=self.config,
  342. inputs_embeds=hidden_states,
  343. attention_mask=attention_mask,
  344. past_key_values=past_key_values,
  345. )
  346. encoder_attention_mask = create_bidirectional_mask(
  347. config=self.config,
  348. inputs_embeds=hidden_states,
  349. attention_mask=encoder_attention_mask,
  350. encoder_hidden_states=encoder_hidden_states,
  351. )
  352. position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
  353. for layer in self.layers:
  354. hidden_states = layer(
  355. hidden_states,
  356. # Needs to be an arg in order to function properly
  357. # on inplace operations to be carried (e.g. compile)
  358. position_embeddings,
  359. attention_mask,
  360. encoder_hidden_states,
  361. encoder_attention_mask=encoder_attention_mask,
  362. past_key_values=past_key_values,
  363. position_ids=position_ids,
  364. **kwargs,
  365. )
  366. hidden_states = self.norm(hidden_states)
  367. return BaseModelOutputWithPastAndCrossAttentions(
  368. last_hidden_state=hidden_states,
  369. past_key_values=past_key_values,
  370. )
  371. @auto_docstring(
  372. custom_intro="""
  373. The bare Dia model outputting raw hidden-states without any specific head on top.
  374. """
  375. )
  376. class DiaModel(DiaPreTrainedModel):
  377. def __init__(self, config: DiaConfig):
  378. super().__init__(config)
  379. self.config = config
  380. self.encoder = DiaEncoder(config.encoder_config)
  381. self.decoder = DiaDecoder(config.decoder_config)
  382. self.post_init()
  383. @auto_docstring
  384. @can_return_tuple
  385. def forward(
  386. self,
  387. input_ids: torch.LongTensor | None = None,
  388. attention_mask: torch.LongTensor | None = None,
  389. decoder_input_ids: torch.LongTensor | None = None,
  390. decoder_position_ids: torch.LongTensor | None = None,
  391. decoder_attention_mask: torch.LongTensor | None = None,
  392. encoder_outputs: BaseModelOutput | tuple | None = None,
  393. past_key_values: EncoderDecoderCache | None = None,
  394. use_cache: bool | None = None,
  395. **kwargs,
  396. ) -> tuple | Seq2SeqModelOutput:
  397. r"""
  398. decoder_input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, target_sequence_length)
  399. or (batch_size, target_sequence_length, num_codebooks)`, *optional*):
  400. 1. (batch_size * num_codebooks, target_sequence_length): corresponds to the general use case where
  401. the audio input codebooks are flattened into the batch dimension. This also aligns with the flat-
  402. tened audio logits which are used to calculate the loss.
  403. 2. (batch_size, sequence_length, num_codebooks): corresponds to the internally used shape of
  404. Dia to calculate embeddings and subsequent steps more efficiently.
  405. If no `decoder_input_ids` are provided, it will create a tensor of `bos_token_id` with shape
  406. `(batch_size, 1, num_codebooks)`. Indices can be obtained using the [`DiaProcessor`]. See
  407. [`DiaProcessor.__call__`] for more details.
  408. [What are decoder input IDs?](../glossary#decoder-input-ids)
  409. decoder_position_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`):
  410. Indices of positions of each input sequence tokens in the position embeddings.
  411. Used to calculate the position embeddings up to `config.decoder_config.max_position_embeddings`.
  412. [What are position IDs?](../glossary#position-ids)
  413. """
  414. if input_ids is None and encoder_outputs is None:
  415. raise ValueError(
  416. "You should either provide text ids or the cached text encodings. Neither has been found."
  417. )
  418. if self.is_gradient_checkpointing and self.training:
  419. if use_cache:
  420. logger.warning_once(
  421. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
  422. )
  423. use_cache = False
  424. if use_cache and past_key_values is None:
  425. past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
  426. if encoder_outputs is None:
  427. encoder_outputs = self.encoder(
  428. input_ids=input_ids,
  429. attention_mask=attention_mask,
  430. **kwargs,
  431. )
  432. # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput
  433. elif not isinstance(encoder_outputs, BaseModelOutput):
  434. encoder_outputs = BaseModelOutput(
  435. last_hidden_state=encoder_outputs[0],
  436. hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
  437. attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
  438. )
  439. # On default we initialize the decoder with bos tokens if nothing has been provided
  440. bsz, seq_len, channels = (encoder_outputs[0].shape[0], -1, self.config.decoder_config.num_channels)
  441. if decoder_input_ids is None:
  442. decoder_input_ids = torch.full(
  443. size=(bsz, 1, channels), fill_value=self.config.decoder_config.bos_token_id, device=self.device
  444. )
  445. # Ensure 3D
  446. if decoder_input_ids.ndim == 2:
  447. decoder_input_ids = decoder_input_ids.reshape(bsz, channels, seq_len).transpose(1, 2)
  448. decoder_outputs = self.decoder(
  449. input_ids=decoder_input_ids,
  450. position_ids=decoder_position_ids,
  451. attention_mask=decoder_attention_mask,
  452. encoder_hidden_states=encoder_outputs[0],
  453. encoder_attention_mask=attention_mask,
  454. past_key_values=past_key_values,
  455. use_cache=use_cache,
  456. **kwargs,
  457. )
  458. return Seq2SeqModelOutput(
  459. last_hidden_state=decoder_outputs.last_hidden_state,
  460. past_key_values=decoder_outputs.past_key_values,
  461. decoder_hidden_states=decoder_outputs.hidden_states,
  462. decoder_attentions=decoder_outputs.attentions,
  463. cross_attentions=decoder_outputs.cross_attentions,
  464. encoder_last_hidden_state=encoder_outputs[0],
  465. encoder_hidden_states=encoder_outputs.hidden_states,
  466. encoder_attentions=encoder_outputs.attentions,
  467. )
  468. @auto_docstring(
  469. custom_intro="""
  470. The Dia model consisting of a (byte) text encoder and audio decoder with a prediction head on top.
  471. """
  472. )
  473. class DiaForConditionalGeneration(DiaPreTrainedModel, DiaGenerationMixin):
  474. base_model_prefix = "model"
  475. output_modalities = ("audio",)
  476. def __init__(self, config: DiaConfig):
  477. super().__init__(config)
  478. self.config = config
  479. self.model = DiaModel(config)
  480. self.num_channels = config.decoder_config.num_channels
  481. self.vocab_size = config.decoder_config.vocab_size
  482. self.logits_dense = nn.Linear(
  483. config.decoder_config.hidden_size, (self.num_channels * self.vocab_size), bias=False
  484. )
  485. self.loss_type = "ForMaskedLM"
  486. # Initialize weights and apply final processing
  487. self.post_init()
  488. @auto_docstring
  489. @can_return_tuple
  490. def forward(
  491. self,
  492. input_ids: torch.LongTensor | None = None,
  493. attention_mask: torch.LongTensor | None = None,
  494. decoder_input_ids: torch.LongTensor | None = None,
  495. decoder_position_ids: torch.LongTensor | None = None,
  496. decoder_attention_mask: torch.LongTensor | None = None,
  497. encoder_outputs: BaseModelOutput | tuple | None = None,
  498. past_key_values: EncoderDecoderCache | None = None,
  499. use_cache: bool | None = None,
  500. labels: torch.LongTensor | None = None,
  501. **kwargs,
  502. ) -> tuple | Seq2SeqLMOutput:
  503. r"""
  504. decoder_input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, target_sequence_length)
  505. or (batch_size, target_sequence_length, num_codebooks)`, *optional*):
  506. 1. (batch_size * num_codebooks, target_sequence_length): corresponds to the general use case where
  507. the audio input codebooks are flattened into the batch dimension. This also aligns with the flat-
  508. tened audio logits which are used to calculate the loss.
  509. 2. (batch_size, sequence_length, num_codebooks): corresponds to the internally used shape of
  510. Dia to calculate embeddings and subsequent steps more efficiently.
  511. If no `decoder_input_ids` are provided, it will create a tensor of `bos_token_id` with shape
  512. `(batch_size, 1, num_codebooks)`. Indices can be obtained using the [`DiaProcessor`]. See
  513. [`DiaProcessor.__call__`] for more details.
  514. [What are decoder input IDs?](../glossary#decoder-input-ids)
  515. decoder_position_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`):
  516. Indices of positions of each input sequence tokens in the position embeddings.
  517. Used to calculate the position embeddings up to `config.decoder_config.max_position_embeddings`.
  518. [What are position IDs?](../glossary#position-ids)
  519. labels (`torch.LongTensor` of shape `(batch_size * num_codebooks,)`, *optional*):
  520. Labels for computing the masked language modeling loss. Indices should either be in
  521. `[0, ..., config.decoder_config.vocab_size - 1]` or -100. Tokens with indices set to `-100`
  522. are ignored (masked).
  523. """
  524. outputs = self.model(
  525. input_ids=input_ids,
  526. attention_mask=attention_mask,
  527. decoder_input_ids=decoder_input_ids,
  528. decoder_position_ids=decoder_position_ids,
  529. decoder_attention_mask=decoder_attention_mask,
  530. encoder_outputs=encoder_outputs,
  531. past_key_values=past_key_values,
  532. use_cache=use_cache,
  533. **kwargs,
  534. )
  535. last_hidden_state = outputs[0]
  536. batch_size = last_hidden_state.shape[0]
  537. # 3D <-> 2D makes it necessary to prioritize channel dim
  538. audio_logits = (
  539. self.logits_dense(last_hidden_state)
  540. .view((batch_size, -1, self.num_channels, self.vocab_size))
  541. .transpose(1, 2)
  542. .contiguous()
  543. .view(batch_size * self.num_channels, -1, self.vocab_size)
  544. )
  545. loss = None
  546. if labels is not None:
  547. loss = self.loss_function(logits=audio_logits, labels=labels, vocab_size=self.vocab_size, **kwargs)
  548. return Seq2SeqLMOutput(
  549. loss=loss,
  550. logits=audio_logits,
  551. past_key_values=outputs.past_key_values,
  552. decoder_hidden_states=outputs.decoder_hidden_states,
  553. decoder_attentions=outputs.decoder_attentions,
  554. cross_attentions=outputs.cross_attentions,
  555. encoder_last_hidden_state=outputs.encoder_last_hidden_state,
  556. encoder_hidden_states=outputs.encoder_hidden_states,
  557. encoder_attentions=outputs.encoder_attentions,
  558. )
# Public classes exported by this module.
__all__ = ["DiaModel", "DiaPreTrainedModel", "DiaForConditionalGeneration"]