modeling_dia.py 37 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/dia/modular_dia.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_dia.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2025 The Nari Labs and HuggingFace Inc. team. All rights reserved.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. from collections.abc import Callable
  21. from typing import Optional
  22. import torch
  23. from torch import nn
  24. from ... import initialization as init
  25. from ...activations import ACT2FN
  26. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  27. from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
  28. from ...masking_utils import create_bidirectional_mask, create_causal_mask
  29. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  30. from ...modeling_layers import GradientCheckpointingLayer
  31. from ...modeling_outputs import (
  32. BaseModelOutput,
  33. BaseModelOutputWithPastAndCrossAttentions,
  34. Seq2SeqLMOutput,
  35. Seq2SeqModelOutput,
  36. )
  37. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  38. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  39. from ...processing_utils import Unpack
  40. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging
  41. from ...utils.generic import maybe_autocast, merge_with_config_defaults
  42. from ...utils.output_capturing import capture_outputs
  43. from .configuration_dia import DiaConfig, DiaDecoderConfig, DiaEncoderConfig
  44. from .generation_dia import DiaGenerationMixin
  45. logger = logging.get_logger(__name__)
  46. @auto_docstring
  47. class DiaPreTrainedModel(PreTrainedModel):
  48. config: DiaConfig
  49. base_model_prefix = "model"
  50. supports_gradient_checkpointing = True
  51. _supports_flash_attn = True
  52. _supports_sdpa = True
  53. _supports_flex_attn = True
  54. _can_compile_fullgraph = True
  55. main_input_name = "input_ids"
  56. _no_split_modules = ["DiaEncoderLayer", "DiaDecoderLayer"]
  57. def _init_weights(self, module):
  58. super()._init_weights(module)
  59. if isinstance(module, DiaMultiChannelEmbedding):
  60. offsets = torch.arange(self.config.num_channels, dtype=torch.long) * self.config.vocab_size
  61. init.copy_(module.offsets, offsets)
  62. class DiaMultiChannelEmbedding(nn.Module):
  63. """In order to efficiently compute the audio embedding from the 9 different channels,
  64. we vectorize the embedding process by using a single embedding layer and an offset.
  65. Example:
  66. - num_embeds = 4
  67. - vocab_size = 8
  68. - num_channels = 3
  69. We would have offsets = [0, 8, 16]
  70. If audio_codes = [0, 1, 2, 3], [1, 3, 4, 7], [5, 6, 7, 8],
  71. then tokens = audio_codes + offsets
  72. = [0, 1, 2, 3, 9, 11, 12, 15, 21, 22, 23, 24]
  73. This allows us to use a single embedding layer for all channels.
  74. """
  75. def __init__(self, config: DiaDecoderConfig):
  76. super().__init__()
  77. self.embed = nn.Embedding(config.vocab_size * config.num_channels, config.hidden_size)
  78. self.hidden_size = config.hidden_size
  79. self.num_channels = config.num_channels
  80. offsets = torch.arange(config.num_channels, dtype=torch.long) * config.vocab_size # (C,)
  81. self.register_buffer("offsets", offsets, persistent=False)
  82. def forward(self, audio_codes: torch.Tensor) -> torch.Tensor:
  83. tokens = (audio_codes + self.offsets.to(audio_codes.device)).squeeze(1)
  84. embeds = self.embed(tokens).view(tokens.shape[0], audio_codes.shape[1], -1, self.hidden_size)
  85. return embeds.sum(dim=2)
  86. class DiaMLP(nn.Module):
  87. def __init__(self, config):
  88. super().__init__()
  89. self.config = config
  90. self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False)
  91. self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
  92. self.activation_fn = ACT2FN[config.hidden_act]
  93. def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
  94. up_states = self.gate_up_proj(hidden_states)
  95. gate, up_states = up_states.chunk(2, dim=-1)
  96. up_states = up_states * self.activation_fn(gate)
  97. return self.down_proj(up_states)
  98. @use_kernel_forward_from_hub("RMSNorm")
  99. class DiaRMSNorm(nn.Module):
  100. def __init__(self, hidden_size, eps: float = 1e-6) -> None:
  101. """
  102. DiaRMSNorm is equivalent to T5LayerNorm
  103. """
  104. super().__init__()
  105. self.weight = nn.Parameter(torch.ones(hidden_size))
  106. self.variance_epsilon = eps
  107. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  108. input_dtype = hidden_states.dtype
  109. hidden_states = hidden_states.to(torch.float32)
  110. variance = hidden_states.pow(2).mean(-1, keepdim=True)
  111. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  112. return self.weight * hidden_states.to(input_dtype)
  113. def extra_repr(self):
  114. return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  115. class DiaRotaryEmbedding(nn.Module):
  116. inv_freq: torch.Tensor # fix linting for `register_buffer`
  117. def __init__(self, config: DiaConfig, device=None):
  118. super().__init__()
  119. self.max_seq_len_cached = config.max_position_embeddings
  120. self.original_max_seq_len = config.max_position_embeddings
  121. self.config = config
  122. self.rope_type = self.config.rope_parameters["rope_type"]
  123. rope_init_fn: Callable = self.compute_default_rope_parameters
  124. if self.rope_type != "default":
  125. rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  126. inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
  127. self.register_buffer("inv_freq", inv_freq, persistent=False)
  128. self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
  129. @staticmethod
  130. def compute_default_rope_parameters(
  131. config: DiaConfig | None = None,
  132. device: Optional["torch.device"] = None,
  133. seq_len: int | None = None,
  134. ) -> tuple["torch.Tensor", float]:
  135. """
  136. Computes the inverse frequencies according to the original RoPE implementation
  137. Args:
  138. config ([`~transformers.PreTrainedConfig`]):
  139. The model configuration.
  140. device (`torch.device`):
  141. The device to use for initialization of the inverse frequencies.
  142. seq_len (`int`, *optional*):
  143. The current sequence length. Unused for this type of RoPE.
  144. Returns:
  145. Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
  146. post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
  147. """
  148. base = config.rope_parameters["rope_theta"]
  149. dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
  150. attention_factor = 1.0 # Unused in this type of RoPE
  151. # Compute the inverse frequencies
  152. inv_freq = 1.0 / (
  153. base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
  154. )
  155. return inv_freq, attention_factor
  156. @torch.no_grad()
  157. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  158. def forward(self, x, position_ids):
  159. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
  160. position_ids_expanded = position_ids[:, None, :].float()
  161. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  162. with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  163. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  164. emb = torch.cat((freqs, freqs), dim=-1)
  165. cos = emb.cos() * self.attention_scaling
  166. sin = emb.sin() * self.attention_scaling
  167. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  168. def rotate_half(x):
  169. """Rotates half the hidden dims of the input."""
  170. x1 = x[..., : x.shape[-1] // 2]
  171. x2 = x[..., x.shape[-1] // 2 :]
  172. return torch.cat((-x2, x1), dim=-1)
  173. @use_kernel_func_from_hub("rotary_pos_emb")
  174. def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
  175. """Applies Rotary Position Embedding to the query and key tensors.
  176. Args:
  177. q (`torch.Tensor`): The query tensor.
  178. k (`torch.Tensor`): The key tensor.
  179. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  180. sin (`torch.Tensor`): The sine part of the rotary embedding.
  181. unsqueeze_dim (`int`, *optional*, defaults to 1):
  182. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  183. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  184. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  185. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  186. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  187. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  188. Returns:
  189. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  190. """
  191. cos = cos.unsqueeze(unsqueeze_dim)
  192. sin = sin.unsqueeze(unsqueeze_dim)
  193. q_embed = (q * cos) + (rotate_half(q) * sin)
  194. k_embed = (k * cos) + (rotate_half(k) * sin)
  195. return q_embed, k_embed
  196. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  197. """
  198. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  199. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  200. """
  201. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  202. if n_rep == 1:
  203. return hidden_states
  204. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  205. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  206. def eager_attention_forward(
  207. module: nn.Module,
  208. query: torch.Tensor,
  209. key: torch.Tensor,
  210. value: torch.Tensor,
  211. attention_mask: torch.Tensor | None,
  212. scaling: float,
  213. dropout: float = 0.0,
  214. **kwargs: Unpack[TransformersKwargs],
  215. ):
  216. key_states = repeat_kv(key, module.num_key_value_groups)
  217. value_states = repeat_kv(value, module.num_key_value_groups)
  218. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  219. if attention_mask is not None:
  220. attn_weights = attn_weights + attention_mask
  221. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  222. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  223. attn_output = torch.matmul(attn_weights, value_states)
  224. attn_output = attn_output.transpose(1, 2).contiguous()
  225. return attn_output, attn_weights
  226. @use_kernelized_func(apply_rotary_pos_emb)
  227. class DiaSelfAttention(nn.Module):
  228. """Multi-headed attention from 'Attention Is All You Need' paper"""
  229. def __init__(self, config: DiaEncoderConfig | DiaDecoderConfig, layer_idx: int, is_causal: bool = False):
  230. super().__init__()
  231. self.config = config
  232. self.layer_idx = layer_idx
  233. self.hidden_size = config.hidden_size
  234. self.num_heads = self.config.num_attention_heads
  235. self.num_key_value_heads = self.config.num_key_value_heads or self.num_heads
  236. self.num_key_value_groups = self.num_heads // self.num_key_value_heads
  237. self.head_dim = getattr(config, "head_dim", config.hidden_size // self.num_heads)
  238. self.scaling = 1
  239. self.attention_dropout = 0.0
  240. self.is_causal = is_causal
  241. self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
  242. self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
  243. self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
  244. self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
  245. def forward(
  246. self,
  247. hidden_states: torch.Tensor,
  248. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  249. attention_mask: torch.Tensor | None = None,
  250. past_key_values: Cache | None = None,
  251. **kwargs: Unpack[TransformersKwargs],
  252. ) -> tuple[torch.Tensor, torch.Tensor]:
  253. input_shape = hidden_states.shape[:-1]
  254. hidden_shape = (*input_shape, -1, self.head_dim)
  255. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  256. key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  257. value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  258. cos, sin = position_embeddings
  259. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  260. if past_key_values is not None:
  261. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  262. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  263. self.config._attn_implementation, eager_attention_forward
  264. )
  265. attn_output, attn_weights = attention_interface(
  266. self,
  267. query_states,
  268. key_states,
  269. value_states,
  270. attention_mask,
  271. dropout=0.0 if not self.training else self.attention_dropout,
  272. scaling=self.scaling,
  273. **kwargs,
  274. )
  275. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  276. attn_output = self.o_proj(attn_output)
  277. return attn_output, attn_weights
  278. class DiaCrossAttention(nn.Module):
  279. """Multi-headed attention from 'Attention Is All You Need' paper"""
  280. def __init__(self, config: DiaDecoderConfig, layer_idx: int):
  281. super().__init__()
  282. self.config = config
  283. self.layer_idx = layer_idx
  284. self.hidden_size = config.hidden_size
  285. self.cross_hidden_size = config.cross_hidden_size
  286. self.num_heads = self.config.cross_num_attention_heads
  287. self.num_key_value_heads = self.config.cross_num_key_value_heads
  288. self.num_key_value_groups = self.num_heads // self.num_key_value_heads
  289. self.head_dim = config.cross_head_dim
  290. self.scaling = 1
  291. self.attention_dropout = 0.0
  292. self.is_causal = False
  293. self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
  294. self.k_proj = nn.Linear(self.cross_hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
  295. self.v_proj = nn.Linear(self.cross_hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
  296. self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
  297. def forward(
  298. self,
  299. hidden_states: torch.Tensor,
  300. cross_attention_states: torch.Tensor,
  301. attention_mask: torch.Tensor | None = None,
  302. past_key_values: EncoderDecoderCache | None = None,
  303. **kwargs: Unpack[FlashAttentionKwargs],
  304. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  305. input_shape = hidden_states.shape[:-1]
  306. hidden_shape = (*input_shape, -1, self.head_dim)
  307. cross_shape = (*cross_attention_states.shape[:-1], -1, self.head_dim)
  308. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  309. is_updated = past_key_values.is_updated.get(self.layer_idx) if past_key_values is not None else False
  310. if past_key_values is not None and is_updated:
  311. # reuse k,v, cross_attentions
  312. key_states = past_key_values.cross_attention_cache.layers[self.layer_idx].keys
  313. value_states = past_key_values.cross_attention_cache.layers[self.layer_idx].values
  314. else:
  315. key_states = self.k_proj(cross_attention_states).view(cross_shape).transpose(1, 2)
  316. value_states = self.v_proj(cross_attention_states).view(cross_shape).transpose(1, 2)
  317. if past_key_values is not None:
  318. # save all states to the cache
  319. key_states, value_states = past_key_values.cross_attention_cache.update(
  320. key_states,
  321. value_states,
  322. self.layer_idx,
  323. )
  324. # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
  325. past_key_values.is_updated[self.layer_idx] = True
  326. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  327. self.config._attn_implementation, eager_attention_forward
  328. )
  329. attn_output, attn_weights = attention_interface(
  330. self,
  331. query_states,
  332. key_states,
  333. value_states,
  334. attention_mask,
  335. scaling=self.scaling,
  336. **kwargs,
  337. )
  338. attn_output = attn_output.reshape((*input_shape, -1)).contiguous()
  339. attn_output = self.o_proj(attn_output)
  340. return attn_output, attn_weights
  341. class DiaEncoderLayer(GradientCheckpointingLayer):
  342. def __init__(self, config: DiaEncoderConfig, layer_idx: int):
  343. super().__init__()
  344. self.pre_sa_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
  345. self.self_attention = DiaSelfAttention(config, layer_idx, is_causal=False)
  346. self.post_sa_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
  347. self.mlp = DiaMLP(config)
  348. def forward(
  349. self,
  350. hidden_states: torch.Tensor,
  351. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  352. attention_mask: torch.Tensor | None = None,
  353. **kwargs: Unpack[FlashAttentionKwargs],
  354. ) -> tuple[torch.Tensor, torch.Tensor | None]:
  355. residual = hidden_states
  356. normed_states = self.pre_sa_norm(hidden_states)
  357. self_attn_output, _ = self.self_attention(
  358. normed_states,
  359. position_embeddings=position_embeddings,
  360. attention_mask=attention_mask,
  361. **kwargs,
  362. )
  363. hidden_states = residual + self_attn_output
  364. residual = hidden_states
  365. normed_states = self.post_sa_norm(hidden_states)
  366. mlp_out = self.mlp(normed_states)
  367. hidden_states = residual + mlp_out
  368. return hidden_states
  369. class DiaEncoder(DiaPreTrainedModel):
  370. _can_record_outputs = {
  371. "hidden_states": DiaEncoderLayer,
  372. "attentions": DiaSelfAttention,
  373. }
  374. def __init__(self, config: DiaEncoderConfig):
  375. super().__init__(config)
  376. self.config = config
  377. self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
  378. self.layers = nn.ModuleList(
  379. [DiaEncoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  380. )
  381. self.norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
  382. self.rotary_emb = DiaRotaryEmbedding(config=config)
  383. self.post_init()
  384. @merge_with_config_defaults
  385. @capture_outputs
  386. @auto_docstring
  387. def forward(
  388. self,
  389. input_ids: torch.Tensor,
  390. attention_mask: torch.Tensor | None = None,
  391. **kwargs: Unpack[TransformersKwargs],
  392. ) -> BaseModelOutput:
  393. hidden_states = self.embedding(input_ids)
  394. # RoPE
  395. # Note: We expect right padding and hence always generate
  396. # the position ids on the fly to reduce preparation overhead
  397. position_ids = torch.arange(input_ids.shape[-1], device=input_ids.device)[None, :]
  398. attention_mask = create_bidirectional_mask(
  399. config=self.config,
  400. inputs_embeds=hidden_states,
  401. attention_mask=attention_mask,
  402. )
  403. position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
  404. for encoder_layer in self.layers:
  405. hidden_states = encoder_layer(
  406. hidden_states,
  407. attention_mask=attention_mask,
  408. position_ids=position_ids,
  409. position_embeddings=position_embeddings,
  410. **kwargs,
  411. )
  412. hidden_states = self.norm(hidden_states)
  413. return BaseModelOutput(last_hidden_state=hidden_states)
  414. class DiaDecoderLayer(GradientCheckpointingLayer):
  415. def __init__(self, config: DiaDecoderConfig, layer_idx: int):
  416. super().__init__()
  417. self.embed_dim = config.hidden_size
  418. self.self_attention = DiaSelfAttention(config, layer_idx, is_causal=True)
  419. self.cross_attention = DiaCrossAttention(config, layer_idx)
  420. self.pre_sa_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
  421. self.pre_ca_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
  422. self.pre_mlp_norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
  423. self.mlp = DiaMLP(config)
  424. def forward(
  425. self,
  426. hidden_states: torch.Tensor,
  427. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  428. attention_mask: torch.Tensor | None = None,
  429. encoder_hidden_states: torch.Tensor | None = None,
  430. encoder_attention_mask: torch.Tensor | None = None,
  431. past_key_values: EncoderDecoderCache | None = None,
  432. **kwargs,
  433. ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
  434. self_attn_cache = past_key_values
  435. if isinstance(self_attn_cache, EncoderDecoderCache):
  436. self_attn_cache = self_attn_cache.self_attention_cache
  437. residual = hidden_states
  438. normed_states = self.pre_sa_norm(hidden_states)
  439. self_attn_output, _ = self.self_attention(
  440. normed_states,
  441. position_embeddings,
  442. attention_mask,
  443. # Needs to be an arg in order to function properly
  444. # on inplace operations to be carried (e.g. compile)
  445. self_attn_cache,
  446. **kwargs,
  447. )
  448. hidden_states = residual + self_attn_output
  449. residual = hidden_states
  450. normed_states = self.pre_ca_norm(hidden_states)
  451. cross_states, _ = self.cross_attention(
  452. normed_states,
  453. encoder_hidden_states,
  454. attention_mask=encoder_attention_mask,
  455. past_key_values=past_key_values,
  456. **kwargs,
  457. )
  458. hidden_states = residual + cross_states
  459. residual = hidden_states
  460. normed_states = self.pre_mlp_norm(hidden_states)
  461. mlp_out = self.mlp(normed_states)
  462. hidden_states = residual + mlp_out
  463. return hidden_states
  464. class DiaDecoder(DiaPreTrainedModel):
  465. """Transformer Decoder Stack using DenseGeneral."""
  466. _can_record_outputs = {
  467. "hidden_states": DiaDecoderLayer,
  468. "attentions": [DiaSelfAttention, DiaCrossAttention],
  469. }
  470. def __init__(self, config: DiaDecoderConfig):
  471. super().__init__(config)
  472. self.num_channels = config.num_channels
  473. self.vocab_size = config.vocab_size
  474. self.embeddings = DiaMultiChannelEmbedding(config)
  475. self.layers = nn.ModuleList(
  476. [DiaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  477. )
  478. self.norm = DiaRMSNorm(config.hidden_size, eps=config.norm_eps)
  479. self.rotary_emb = DiaRotaryEmbedding(config=config)
  480. self.post_init()
  481. @merge_with_config_defaults
  482. @capture_outputs
  483. @auto_docstring
  484. def forward(
  485. self,
  486. input_ids: torch.Tensor,
  487. position_ids: torch.LongTensor | None = None,
  488. attention_mask: torch.Tensor | None = None,
  489. encoder_hidden_states: torch.FloatTensor | None = None,
  490. encoder_attention_mask: torch.LongTensor | None = None,
  491. past_key_values: EncoderDecoderCache | None = None,
  492. **kwargs: Unpack[TransformersKwargs],
  493. ) -> BaseModelOutputWithPastAndCrossAttentions | tuple:
  494. r"""
  495. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`):
  496. The original `decoder_input_ids` in 3D shape to facilitate more efficient computations.
  497. [What are input IDs?](../glossary#input-ids)
  498. """
  499. batch_size, seq_length = input_ids.size()[:-1]
  500. past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
  501. if position_ids is None:
  502. position_ids = torch.arange(seq_length, device=input_ids.device) + past_key_values_length
  503. position_ids = position_ids.unsqueeze(0)
  504. # RoPE
  505. hidden_states = self.embeddings(input_ids)
  506. if attention_mask is None and not is_torchdynamo_compiling():
  507. # required mask seq length can be calculated via length of past cache
  508. mask_seq_length = past_key_values_length + seq_length
  509. attention_mask = torch.ones(batch_size, mask_seq_length, device=input_ids.device)
  510. attention_mask = create_causal_mask(
  511. config=self.config,
  512. inputs_embeds=hidden_states,
  513. attention_mask=attention_mask,
  514. past_key_values=past_key_values,
  515. )
  516. encoder_attention_mask = create_bidirectional_mask(
  517. config=self.config,
  518. inputs_embeds=hidden_states,
  519. attention_mask=encoder_attention_mask,
  520. encoder_hidden_states=encoder_hidden_states,
  521. )
  522. position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
  523. for layer in self.layers:
  524. hidden_states = layer(
  525. hidden_states,
  526. # Needs to be an arg in order to function properly
  527. # on inplace operations to be carried (e.g. compile)
  528. position_embeddings,
  529. attention_mask,
  530. encoder_hidden_states,
  531. encoder_attention_mask=encoder_attention_mask,
  532. past_key_values=past_key_values,
  533. position_ids=position_ids,
  534. **kwargs,
  535. )
  536. hidden_states = self.norm(hidden_states)
  537. return BaseModelOutputWithPastAndCrossAttentions(
  538. last_hidden_state=hidden_states,
  539. past_key_values=past_key_values,
  540. )
  541. @auto_docstring(
  542. custom_intro="""
  543. The bare Dia model outputting raw hidden-states without any specific head on top.
  544. """
  545. )
  546. class DiaModel(DiaPreTrainedModel):
  547. def __init__(self, config: DiaConfig):
  548. super().__init__(config)
  549. self.config = config
  550. self.encoder = DiaEncoder(config.encoder_config)
  551. self.decoder = DiaDecoder(config.decoder_config)
  552. self.post_init()
  553. @auto_docstring
  554. @can_return_tuple
  555. def forward(
  556. self,
  557. input_ids: torch.LongTensor | None = None,
  558. attention_mask: torch.LongTensor | None = None,
  559. decoder_input_ids: torch.LongTensor | None = None,
  560. decoder_position_ids: torch.LongTensor | None = None,
  561. decoder_attention_mask: torch.LongTensor | None = None,
  562. encoder_outputs: BaseModelOutput | tuple | None = None,
  563. past_key_values: EncoderDecoderCache | None = None,
  564. use_cache: bool | None = None,
  565. **kwargs,
  566. ) -> tuple | Seq2SeqModelOutput:
  567. r"""
  568. decoder_input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, target_sequence_length)
  569. or (batch_size, target_sequence_length, num_codebooks)`, *optional*):
  570. 1. (batch_size * num_codebooks, target_sequence_length): corresponds to the general use case where
  571. the audio input codebooks are flattened into the batch dimension. This also aligns with the flat-
  572. tened audio logits which are used to calculate the loss.
  573. 2. (batch_size, sequence_length, num_codebooks): corresponds to the internally used shape of
  574. Dia to calculate embeddings and subsequent steps more efficiently.
  575. If no `decoder_input_ids` are provided, it will create a tensor of `bos_token_id` with shape
  576. `(batch_size, 1, num_codebooks)`. Indices can be obtained using the [`DiaProcessor`]. See
  577. [`DiaProcessor.__call__`] for more details.
  578. [What are decoder input IDs?](../glossary#decoder-input-ids)
  579. decoder_position_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`):
  580. Indices of positions of each input sequence tokens in the position embeddings.
  581. Used to calculate the position embeddings up to `config.decoder_config.max_position_embeddings`.
  582. [What are position IDs?](../glossary#position-ids)
  583. """
  584. if input_ids is None and encoder_outputs is None:
  585. raise ValueError(
  586. "You should either provide text ids or the cached text encodings. Neither has been found."
  587. )
  588. if self.is_gradient_checkpointing and self.training:
  589. if use_cache:
  590. logger.warning_once(
  591. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
  592. )
  593. use_cache = False
  594. if use_cache and past_key_values is None:
  595. past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
  596. if encoder_outputs is None:
  597. encoder_outputs = self.encoder(
  598. input_ids=input_ids,
  599. attention_mask=attention_mask,
  600. **kwargs,
  601. )
  602. # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput
  603. elif not isinstance(encoder_outputs, BaseModelOutput):
  604. encoder_outputs = BaseModelOutput(
  605. last_hidden_state=encoder_outputs[0],
  606. hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
  607. attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
  608. )
  609. # On default we initialize the decoder with bos tokens if nothing has been provided
  610. bsz, seq_len, channels = (encoder_outputs[0].shape[0], -1, self.config.decoder_config.num_channels)
  611. if decoder_input_ids is None:
  612. decoder_input_ids = torch.full(
  613. size=(bsz, 1, channels), fill_value=self.config.decoder_config.bos_token_id, device=self.device
  614. )
  615. # Ensure 3D
  616. if decoder_input_ids.ndim == 2:
  617. decoder_input_ids = decoder_input_ids.reshape(bsz, channels, seq_len).transpose(1, 2)
  618. decoder_outputs = self.decoder(
  619. input_ids=decoder_input_ids,
  620. position_ids=decoder_position_ids,
  621. attention_mask=decoder_attention_mask,
  622. encoder_hidden_states=encoder_outputs[0],
  623. encoder_attention_mask=attention_mask,
  624. past_key_values=past_key_values,
  625. use_cache=use_cache,
  626. **kwargs,
  627. )
  628. return Seq2SeqModelOutput(
  629. last_hidden_state=decoder_outputs.last_hidden_state,
  630. past_key_values=decoder_outputs.past_key_values,
  631. decoder_hidden_states=decoder_outputs.hidden_states,
  632. decoder_attentions=decoder_outputs.attentions,
  633. cross_attentions=decoder_outputs.cross_attentions,
  634. encoder_last_hidden_state=encoder_outputs[0],
  635. encoder_hidden_states=encoder_outputs.hidden_states,
  636. encoder_attentions=encoder_outputs.attentions,
  637. )
  638. @auto_docstring(
  639. custom_intro="""
  640. The Dia model consisting of a (byte) text encoder and audio decoder with a prediction head on top.
  641. """
  642. )
  643. class DiaForConditionalGeneration(DiaPreTrainedModel, DiaGenerationMixin):
  644. base_model_prefix = "model"
  645. output_modalities = ("audio",)
  646. def __init__(self, config: DiaConfig):
  647. super().__init__(config)
  648. self.config = config
  649. self.model = DiaModel(config)
  650. self.num_channels = config.decoder_config.num_channels
  651. self.vocab_size = config.decoder_config.vocab_size
  652. self.logits_dense = nn.Linear(
  653. config.decoder_config.hidden_size, (self.num_channels * self.vocab_size), bias=False
  654. )
  655. self.loss_type = "ForMaskedLM"
  656. # Initialize weights and apply final processing
  657. self.post_init()
  658. @auto_docstring
  659. @can_return_tuple
  660. def forward(
  661. self,
  662. input_ids: torch.LongTensor | None = None,
  663. attention_mask: torch.LongTensor | None = None,
  664. decoder_input_ids: torch.LongTensor | None = None,
  665. decoder_position_ids: torch.LongTensor | None = None,
  666. decoder_attention_mask: torch.LongTensor | None = None,
  667. encoder_outputs: BaseModelOutput | tuple | None = None,
  668. past_key_values: EncoderDecoderCache | None = None,
  669. use_cache: bool | None = None,
  670. labels: torch.LongTensor | None = None,
  671. **kwargs,
  672. ) -> tuple | Seq2SeqLMOutput:
  673. r"""
  674. decoder_input_ids (`torch.LongTensor` of shape `(batch_size * num_codebooks, target_sequence_length)
  675. or (batch_size, target_sequence_length, num_codebooks)`, *optional*):
  676. 1. (batch_size * num_codebooks, target_sequence_length): corresponds to the general use case where
  677. the audio input codebooks are flattened into the batch dimension. This also aligns with the flat-
  678. tened audio logits which are used to calculate the loss.
  679. 2. (batch_size, sequence_length, num_codebooks): corresponds to the internally used shape of
  680. Dia to calculate embeddings and subsequent steps more efficiently.
  681. If no `decoder_input_ids` are provided, it will create a tensor of `bos_token_id` with shape
  682. `(batch_size, 1, num_codebooks)`. Indices can be obtained using the [`DiaProcessor`]. See
  683. [`DiaProcessor.__call__`] for more details.
  684. [What are decoder input IDs?](../glossary#decoder-input-ids)
  685. decoder_position_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`):
  686. Indices of positions of each input sequence tokens in the position embeddings.
  687. Used to calculate the position embeddings up to `config.decoder_config.max_position_embeddings`.
  688. [What are position IDs?](../glossary#position-ids)
  689. labels (`torch.LongTensor` of shape `(batch_size * num_codebooks,)`, *optional*):
  690. Labels for computing the masked language modeling loss. Indices should either be in
  691. `[0, ..., config.decoder_config.vocab_size - 1]` or -100. Tokens with indices set to `-100`
  692. are ignored (masked).
  693. """
  694. outputs = self.model(
  695. input_ids=input_ids,
  696. attention_mask=attention_mask,
  697. decoder_input_ids=decoder_input_ids,
  698. decoder_position_ids=decoder_position_ids,
  699. decoder_attention_mask=decoder_attention_mask,
  700. encoder_outputs=encoder_outputs,
  701. past_key_values=past_key_values,
  702. use_cache=use_cache,
  703. **kwargs,
  704. )
  705. last_hidden_state = outputs[0]
  706. batch_size = last_hidden_state.shape[0]
  707. # 3D <-> 2D makes it necessary to prioritize channel dim
  708. audio_logits = (
  709. self.logits_dense(last_hidden_state)
  710. .view((batch_size, -1, self.num_channels, self.vocab_size))
  711. .transpose(1, 2)
  712. .contiguous()
  713. .view(batch_size * self.num_channels, -1, self.vocab_size)
  714. )
  715. loss = None
  716. if labels is not None:
  717. loss = self.loss_function(logits=audio_logits, labels=labels, vocab_size=self.vocab_size, **kwargs)
  718. return Seq2SeqLMOutput(
  719. loss=loss,
  720. logits=audio_logits,
  721. past_key_values=outputs.past_key_values,
  722. decoder_hidden_states=outputs.decoder_hidden_states,
  723. decoder_attentions=outputs.decoder_attentions,
  724. cross_attentions=outputs.cross_attentions,
  725. encoder_last_hidden_state=outputs.encoder_last_hidden_state,
  726. encoder_hidden_states=outputs.encoder_hidden_states,
  727. encoder_attentions=outputs.encoder_attentions,
  728. )
  729. __all__ = ["DiaModel", "DiaPreTrainedModel", "DiaForConditionalGeneration"]