modeling_blt.py 62 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/blt/modular_blt.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_blt.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2025 HuggingFace Inc. team. All rights reserved.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. from collections.abc import Callable
  21. from typing import Optional
  22. import torch
  23. import torch.distributions
  24. import torch.nn as nn
  25. import torch.nn.functional as F
  26. from ... import initialization as init
  27. from ...activations import ACT2FN
  28. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  29. from ...generation import GenerationMixin
  30. from ...masking_utils import create_causal_mask
  31. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  32. from ...modeling_layers import GradientCheckpointingLayer
  33. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  34. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  35. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  36. from ...processing_utils import Unpack
  37. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
  38. from ...utils.deprecation import deprecate_kwarg
  39. from ...utils.generic import maybe_autocast, merge_with_config_defaults
  40. from ...utils.output_capturing import OutputRecorder, capture_outputs
  41. from .configuration_blt import (
  42. BltConfig,
  43. BltGlobalTransformerConfig,
  44. BltLocalDecoderConfig,
  45. BltLocalEncoderConfig,
  46. BltPatcherConfig,
  47. )
  class BltMLP(nn.Module):
      """Gated feed-forward block computing ``down_proj(act(gate_proj(x)) * up_proj(x))``.

      All three projections are bias-free linear layers. The activation is looked up in
      ``ACT2FN`` by the name stored in ``config.hidden_act``.
      """

      def __init__(self, config):
          super().__init__()
          self.config = config
          self.hidden_size = config.hidden_size
          self.intermediate_size = config.intermediate_size
          width, inner = self.hidden_size, self.intermediate_size
          self.gate_proj = nn.Linear(width, inner, bias=False)
          self.up_proj = nn.Linear(width, inner, bias=False)
          self.down_proj = nn.Linear(inner, width, bias=False)
          self.act_fn = ACT2FN[config.hidden_act]

      def forward(self, x):
          # The activated gate branch multiplies the linear "up" branch element-wise.
          gated = self.act_fn(self.gate_proj(x))
          return self.down_proj(gated * self.up_proj(x))
  class BltRMSNorm(nn.Module):
      def __init__(self, hidden_size, eps: float = 1e-6) -> None:
          """
          Root-mean-square normalization with a learned per-channel scale (no centering, no bias).
          BltRMSNorm is equivalent to T5LayerNorm.
          """
          super().__init__()
          self.weight = nn.Parameter(torch.ones(hidden_size))
          self.variance_epsilon = eps

      def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
          # Compute the statistics in float32, cast back to the input dtype, then apply the scale.
          original_dtype = hidden_states.dtype
          as_fp32 = hidden_states.to(torch.float32)
          mean_square = as_fp32.pow(2).mean(-1, keepdim=True)
          normalized = as_fp32 * torch.rsqrt(mean_square + self.variance_epsilon)
          return self.weight * normalized.to(original_dtype)

      def extra_repr(self):
          return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
  class BltRotaryEmbedding(nn.Module):
      inv_freq: torch.Tensor  # fix linting for `register_buffer`

      def __init__(self, config: BltConfig, device=None):
          super().__init__()
          self.max_seq_len_cached = config.max_position_embeddings
          self.original_max_seq_len = config.max_position_embeddings
          self.config = config
          self.rope_type = self.config.rope_parameters["rope_type"]
          # "default" uses the local implementation; any other type is looked up in ROPE_INIT_FUNCTIONS.
          if self.rope_type == "default":
              init_fn: Callable = self.compute_default_rope_parameters
          else:
              init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
          inv_freq, self.attention_scaling = init_fn(self.config, device)
          # Non-persistent buffers: kept out of the state dict; a pristine copy is kept alongside.
          self.register_buffer("inv_freq", inv_freq, persistent=False)
          self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

      @staticmethod
      def compute_default_rope_parameters(
          config: BltConfig | None = None,
          device: Optional["torch.device"] = None,
          seq_len: int | None = None,
      ) -> tuple["torch.Tensor", float]:
          """
          Build the standard RoPE inverse frequencies ``1 / theta ** (2i / dim)``.

          Args:
              config ([`~transformers.PreTrainedConfig`]):
                  Model configuration. Reads `rope_parameters["rope_theta"]` and `head_dim`, falling back to
                  `hidden_size // num_attention_heads` when `head_dim` is absent or falsy.
              device (`torch.device`):
                  Device on which the frequencies are created.
              seq_len (`int`, *optional*):
                  Ignored by this RoPE variant.
          Returns:
              `(inv_freq, attention_factor)`: the inverse-frequency tensor of length `dim // 2` (rounded up)
              and a post-processing scale for cos/sin, which is always 1.0 here.
          """
          base = config.rope_parameters["rope_theta"]
          dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
          # Even channel indices 0, 2, 4, ... divided by dim give the exponents 2i / dim.
          exponents = torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim
          inv_freq = 1.0 / (base**exponents)
          return inv_freq, 1.0

      @torch.no_grad()
      @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
      def forward(self, x, position_ids):
          batch = position_ids.shape[0]
          inv_freq_expanded = self.inv_freq[None, :, None].float().expand(batch, -1, 1)
          position_ids_expanded = position_ids[:, None, :].float()
          # Non-string device types and "mps" fall back to a "cpu" autocast context.
          device_type = x.device.type
          if not isinstance(device_type, str) or device_type == "mps":
              device_type = "cpu"
          with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
              freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
              # Unlike Llama (which concatenates), each frequency is repeated for an adjacent channel pair.
              emb = torch.repeat_interleave(freqs, 2, dim=-1)
              cos = emb.cos() * self.attention_scaling
              sin = emb.sin() * self.attention_scaling
          return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  # Modified from transformers.models.llama.modeling_llama.LlamaDecoderLayer
  class BltTransformerLayer(GradientCheckpointingLayer):
      """Pre-norm block: ``h = x + attn(norm(x))`` followed by ``h + mlp(norm(h))``."""

      def __init__(self, config, layer_idx: int):
          super().__init__()
          self.hidden_size = config.hidden_size
          self.self_attn = BltSelfAttention(config=config, layer_idx=layer_idx)
          self.mlp = BltMLP(config)
          self.input_layernorm = BltRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
          self.post_attention_layernorm = BltRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
          self.layer_idx = layer_idx

      def forward(
          self,
          hidden_states: torch.Tensor,
          cross_attention_states: torch.Tensor | None = None,
          cross_attention_mask: torch.Tensor | None = None,
          attention_mask: torch.Tensor | None = None,
          full_text_row_masked_out_mask: tuple[torch.Tensor, torch.Tensor] | None = None,
          position_ids: torch.LongTensor | None = None,
          past_key_values: Cache | None = None,
          use_cache: bool | None = False,
          position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
          **kwargs: Unpack[FlashAttentionKwargs],
      ) -> torch.Tensor:
          """
          Args:
              hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
              cross_attention_states, cross_attention_mask, full_text_row_masked_out_mask:
                  Accepted for signature compatibility; this layer does not use them.
              attention_mask (`torch.FloatTensor`, *optional*):
                  attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
                  query_sequence_length, key_sequence_length)` if default attention is used.
              position_ids (`torch.LongTensor`, *optional*):
                  Forwarded to the self-attention module as an extra keyword argument.
              past_key_values (`Cache`, *optional*): cached past key and value projection states
              use_cache (`bool`, *optional*):
                  Forwarded to the self-attention module as an extra keyword argument.
              position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
                  Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
                  with `head_dim` being the embedding dimension of each attention head.
              kwargs (`dict`, *optional*):
                  Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
                  into the model

          Returns:
              `torch.Tensor`: the updated hidden states, same shape as `hidden_states`. The attention weights
              produced by `self.self_attn` are not returned by this method.
          """
          residual = hidden_states
          hidden_states = self.input_layernorm(hidden_states)
          # Self Attention
          hidden_states, _ = self.self_attn(
              hidden_states=hidden_states,
              attention_mask=attention_mask,
              position_ids=position_ids,
              past_key_values=past_key_values,
              use_cache=use_cache,
              position_embeddings=position_embeddings,
              **kwargs,
          )
          hidden_states = residual + hidden_states
          # Fully Connected
          residual = hidden_states
          hidden_states = self.post_attention_layernorm(hidden_states)
          hidden_states = self.mlp(hidden_states)
          hidden_states = residual + hidden_states
          return hidden_states
  def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
      """
      Repeat every key/value head `n_rep` times along dim 1, the same as
      `torch.repeat_interleave(x, dim=1, repeats=n_rep)`.

      Maps `(batch, num_key_value_heads, seqlen, head_dim)` to
      `(batch, num_key_value_heads * n_rep, seqlen, head_dim)`. When `n_rep == 1` the input is
      returned unchanged (same object).
      """
      batch, kv_heads, seq_len, head_dim = hidden_states.shape
      if n_rep == 1:
          return hidden_states
      # Insert a repeat axis next to the head axis, broadcast it, then fold it into the head axis.
      expanded = hidden_states.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, head_dim)
      return expanded.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
  def eager_attention_forward(
      module: nn.Module,
      query: torch.Tensor,
      key: torch.Tensor,
      value: torch.Tensor,
      attention_mask: torch.Tensor | None,
      scaling: float,
      dropout: float = 0.0,
      **kwargs: Unpack[TransformersKwargs],
  ):
      """Reference scaled dot-product attention in plain PyTorch.

      Key/value heads are repeated `module.num_key_value_groups` times to match the query heads.
      `attention_mask`, if given, is added to the raw scores. Softmax runs in float32 and is cast back
      to `query.dtype`. Dropout is active only when `module.training` is true.

      Returns:
          `(attn_output, attn_weights)`, where `attn_output` has its dims 1 and 2 swapped relative to the
          `probs @ values` product (made contiguous).
      """
      groups = module.num_key_value_groups
      keys = repeat_kv(key, groups)
      values = repeat_kv(value, groups)

      scores = torch.matmul(query, keys.transpose(2, 3)) * scaling
      if attention_mask is not None:
          # Additive mask.
          scores = scores + attention_mask

      probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
      probs = nn.functional.dropout(probs, p=dropout, training=module.training)

      context = torch.matmul(probs, values).transpose(1, 2).contiguous()
      return context, probs
  def rotate_half(x):
      """Rotate adjacent channel pairs along the last dim: (x0, x1, x2, x3, ...) -> (-x1, x0, -x3, x2, ...).

      Note: this pairs even/odd channels, which differs from e.g. Llama's `rotate_half`.
      """
      evens = x[..., 0::2]
      odds = x[..., 1::2]
      return torch.stack((-odds, evens), dim=-1).flatten(start_dim=-2)
  def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
      """Apply rotary position embeddings to the query and key tensors.

      Args:
          q (`torch.Tensor`): The query tensor.
          k (`torch.Tensor`): The key tensor.
          cos (`torch.Tensor`): The cosine part of the rotary embedding.
          sin (`torch.Tensor`): The sine part of the rotary embedding.
          unsqueeze_dim (`int`, *optional*, defaults to 1):
              Axis at which a singleton dimension is inserted into `cos` and `sin` so they broadcast against
              `q` and `k`. For cos/sin of shape [batch_size, seq_len, head_dim], use 1 when q/k are
              [batch_size, heads, seq_len, head_dim], and 2 when they are [batch_size, seq_len, heads, head_dim].
      Returns:
          `tuple(torch.Tensor)`: `(q_embed, k_embed)`, each computed as `t * cos + rotate_half(t) * sin`.
      """
      cos_b = cos.unsqueeze(unsqueeze_dim)
      sin_b = sin.unsqueeze(unsqueeze_dim)

      def _rotate(t):
          return (t * cos_b) + (rotate_half(t) * sin_b)

      return _rotate(q), _rotate(k)
  class BltSelfAttention(nn.Module):
      """Causal multi-head self-attention with grouped key/value heads and rotary position embeddings."""

      def __init__(self, config: BltConfig, layer_idx: int):
          super().__init__()
          self.config = config
          self.num_heads = config.num_attention_heads
          self.dropout = config.dropout
          self.hidden_size = config.hidden_size
          self.num_key_value_heads = config.num_key_value_heads
          self.head_dim = config.hidden_size // self.num_heads
          # Each key/value head is shared by this many query heads (see repeat_kv).
          self.num_key_value_groups = self.num_heads // self.num_key_value_heads
          self.scaling = self.head_dim**-0.5
          self.layer_idx = layer_idx
          self.is_causal = True
          self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
          self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
          self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
          self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

      def forward(
          self,
          hidden_states: torch.Tensor,
          attention_mask: torch.Tensor | None,
          position_embeddings: tuple[torch.Tensor, torch.Tensor],
          past_key_values: Cache | None = None,
          **kwargs,
      ) -> tuple[torch.Tensor, torch.Tensor | None]:
          """
          Args:
              hidden_states: `(batch, q_len, hidden_size)` input.
              attention_mask: mask forwarded unchanged to the selected attention implementation.
              position_embeddings: `(cos, sin)` pair applied to queries and keys via `apply_rotary_pos_emb`.
              past_key_values: optional cache. When given, the new keys/values are appended through
                  `past_key_values.update(..., self.layer_idx)` and the returned full key/value states are attended to.
              kwargs: forwarded to the attention implementation.

          Returns:
              `(attn_output, attn_weights)`: `attn_output` is `(batch, q_len, hidden_size)` after `o_proj`.
              `attn_weights` is whatever the selected attention implementation returns.
          """
          bsz, q_len, _ = hidden_states.size()
          query_states = self.q_proj(hidden_states)
          key_states = self.k_proj(hidden_states)
          value_states = self.v_proj(hidden_states)
          # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
          query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
          key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
          value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
          cos, sin = position_embeddings
          query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
          if past_key_values is not None:
              key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
          attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
              self.config._attn_implementation, eager_attention_forward
          )
          attn_output, attn_weights = attention_interface(
              self,
              query_states,
              key_states,
              value_states,
              attention_mask,
              dropout=0.0 if not self.training else self.dropout,
              scaling=self.scaling,
              **kwargs,
          )
          attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
          attn_output = self.o_proj(attn_output)
          return attn_output, attn_weights
  class BltCrossAttention(nn.Module):
      """Cross-attention module for Blt, following transformers style.

      Queries come from `hidden_states` and keys/values from `cross_attention_states`. Both are
      RMS-normalized before projection, and the un-normalized `hidden_states` is added to the output.
      The returned tensor therefore already contains a residual connection.
      """

      def __init__(self, config: BltConfig, layer_idx: int, hidden_size: int | None = None):
          super().__init__()
          # `hidden_size` is accepted for call-site compatibility but is not used: every size below
          # is derived from `config`.
          self.config = config
          self.num_heads = self.config.num_attention_heads
          self.num_key_value_heads = self.config.num_key_value_heads
          self.dropout = config.dropout
          self.hidden_size = config.hidden_size
          self.head_dim = config.hidden_size // self.num_heads
          self.layer_idx = layer_idx
          self.num_key_value_groups = self.num_heads // self.num_key_value_heads
          self.scaling = self.head_dim**-0.5
          self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
          self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
          self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
          self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
          self.q_norm = BltRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
          self.k_norm = BltRMSNorm(self.hidden_size, eps=config.rms_norm_eps)
          self.is_causal = False

      def forward(
          self,
          hidden_states: torch.Tensor,
          cross_attention_states: torch.Tensor | None = None,
          attention_mask: torch.Tensor | None = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> tuple[torch.Tensor, torch.Tensor | None]:
          """Input shape: Batch x Time x Channel

          Args:
              hidden_states: `(batch, q_len, hidden_size)` query source.
              cross_attention_states: `(batch, kv_len, hidden_size)` key/value source. Required.
              attention_mask: mask forwarded unchanged to the selected attention implementation.

          Returns:
              `(attn_output, attn_weights)`: `attn_output` is `o_proj(attention) + hidden_states`.
              `attn_weights` is whatever the selected attention implementation returns.

          Raises:
              ValueError: if `cross_attention_states` is None.
          """
          if cross_attention_states is None:
              raise ValueError("BltCrossAttention requires `cross_attention_states` (the key/value source).")
          bsz, q_len, _ = hidden_states.size()
          query_states = self.q_norm(hidden_states)
          query_states = self.q_proj(query_states)
          query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
          cross_attention_states = self.k_norm(cross_attention_states)
          key_states = self.k_proj(cross_attention_states)
          value_states = self.v_proj(cross_attention_states)
          # The key/value length may differ from q_len, so it is inferred with -1.
          key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
          value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
          attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
              self.config._attn_implementation, eager_attention_forward
          )
          attn_output, attn_weights = attention_interface(
              self,
              query_states,
              key_states,
              value_states,
              attention_mask,
              dropout=0.0 if not self.training else self.dropout,
              scaling=self.scaling,
              **kwargs,
          )
          attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
          attn_output = self.o_proj(attn_output)
          attn_output = attn_output + hidden_states
          return attn_output, attn_weights
  # Shared base class for the BLT sub-models: declares backend support flags, output recording,
  # and the BLT-specific weight initialization in `_init_weights`.
  @auto_docstring
  class BltPreTrainedModel(PreTrainedModel):
      config: BltConfig
      base_model_prefix = "model"
      # NOTE(review): BLT consumes bytes/text; confirm whether "image" is intended in this tuple.
      input_modalities = ("image", "text")
      supports_gradient_checkpointing = True
      _no_split_modules = ["BltTransformerLayer"]
      _can_compile_fullgraph = False  # static cache cannot have different shapes for each layer
      _supports_sdpa = True
      _supports_flash_attn = False
      _supports_flex_attn = False
      _supports_attention_backend = False
      # Hidden states are recorded from output index 0 of every BltTransformerLayer, and attention
      # weights from output index 1 of every BltSelfAttention.
      # NOTE(review): BltTransformerLayer.forward returns a bare tensor rather than a tuple; confirm
      # that OutputRecorder treats index=0 as the whole output in that case.
      _can_record_outputs = {
          "hidden_states": OutputRecorder(BltTransformerLayer, index=0),
          "attentions": OutputRecorder(BltSelfAttention, index=1),
      }

      @torch.no_grad()
      def _init_weights(self, module):
          """
          Initialize BLT weights following the original ByteLatentTransformer:
          - Most weights are drawn from a truncated normal.
          - Scale is ~ 1 / sqrt(model_dim) (or 1 / sqrt(hidden_dim) for FFN outputs).
          - Norm layers are set to weight = 1, bias = 0.
          """
          # The branches below are checked in order and each returns early (except the final rotary
          # branch). A module is matched by the first branch that applies, so the ordering matters.
          class_name = module.__class__.__name__
          # Norms: RMSNorm / LayerNorm (also matched by class name, for norm classes defined elsewhere)
          if isinstance(module, (BltRMSNorm, nn.LayerNorm)) or "RMSNorm" in class_name or "LayerNorm" in class_name:
              if getattr(module, "weight", None) is not None:
                  init.ones_(module.weight)
              if getattr(module, "bias", None) is not None:
                  init.zeros_(module.bias)
              return
          # Embeddings (encoder / patcher / hash embeddings)
          if isinstance(module, nn.Embedding):
              # std = hidden_size ** -0.5, taking hidden_size from the config, then encoder_config,
              # and finally falling back to the embedding's own width.
              hidden_size = getattr(self.config, "hidden_size", None)
              if hidden_size is None and hasattr(self.config, "encoder_config"):
                  hidden_size = getattr(self.config.encoder_config, "hidden_size", None)
              if hidden_size is None:
                  hidden_size = module.embedding_dim
              std = hidden_size**-0.5
              init.trunc_normal_(
                  module.weight,
                  mean=0.0,
                  std=std,
                  a=-3 * std,
                  b=3 * std,
              )
              # The padding row, if any, is zeroed.
              if module.padding_idx is not None:
                  init.zeros_(module.weight[module.padding_idx])
              return
          # Self-attention / cross-attention projections
          # NOTE(review): the child nn.Linear projections are also nn.Linear modules and can be matched
          # by the generic Linear branch below when visited on their own; which init ends up applied
          # depends on PreTrainedModel's traversal and init bookkeeping — confirm.
          if isinstance(module, (BltSelfAttention, BltCrossAttention)) or class_name in (
              "MllamaTextSelfAttention",
              "MllamaTextCrossAttention",
          ):
              # Model dim resolution: config.hidden_size, then module.hidden_size, then the input width
              # of the first projection found. If none is found, the module is left untouched.
              dim = getattr(self.config, "hidden_size", None)
              if dim is None and hasattr(module, "hidden_size"):
                  dim = module.hidden_size
              if dim is None:
                  for name in ("q_proj", "k_proj", "v_proj", "o_proj", "dense"):
                      proj = getattr(module, name, None)
                      if proj is not None and hasattr(proj, "weight"):
                          dim = proj.weight.shape[-1]
                          break
              if dim is None:
                  return
              std = dim**-0.5
              # Input projections (q, k, v)
              for proj_name in ("q_proj", "k_proj", "v_proj"):
                  proj = getattr(module, proj_name, None)
                  if proj is not None and hasattr(proj, "weight"):
                      init.trunc_normal_(
                          proj.weight,
                          mean=0.0,
                          std=std,
                          a=-3 * std,
                          b=3 * std,
                      )
                      if getattr(proj, "bias", None) is not None:
                          init.zeros_(proj.bias)
              # Output projection: o_proj or dense (same std as the input projections)
              o_proj = getattr(module, "o_proj", getattr(module, "dense", None))
              if o_proj is not None and hasattr(o_proj, "weight"):
                  init.trunc_normal_(
                      o_proj.weight,
                      mean=0.0,
                      std=std,
                      a=-3 * std,
                      b=3 * std,
                  )
                  if getattr(o_proj, "bias", None) is not None:
                      init.zeros_(o_proj.bias)
              return
          # MLP / FFN blocks
          if isinstance(module, BltMLP) or class_name == "MllamaTextMLP":
              # hidden_size from config, then decoder_config, then encoder_config (may stay None).
              hidden_size = getattr(self.config, "hidden_size", None)
              if hidden_size is None and hasattr(self.config, "decoder_config"):
                  hidden_size = getattr(self.config.decoder_config, "hidden_size", None)
              if hidden_size is None and hasattr(self.config, "encoder_config"):
                  hidden_size = getattr(self.config.encoder_config, "hidden_size", None)
              # Input-side std
              in_std = None
              if hidden_size is not None:
                  in_std = hidden_size**-0.5
              gate_proj = getattr(module, "gate_proj", getattr(module, "fc1", None))
              up_proj = getattr(module, "up_proj", None)
              down_proj = getattr(module, "down_proj", getattr(module, "fc2", None))
              # gate / input projections: std falls back to the projection's own fan-in when
              # hidden_size could not be resolved
              for proj in (gate_proj, up_proj):
                  if proj is not None and hasattr(proj, "weight"):
                      std = in_std or (proj.weight.shape[1] ** -0.5)
                      init.trunc_normal_(
                          proj.weight,
                          mean=0.0,
                          std=std,
                          a=-3 * std,
                          b=3 * std,
                      )
                      if getattr(proj, "bias", None) is not None:
                          init.zeros_(proj.bias)
              # output/ down projections: std = (input width of down_proj) ** -0.5
              if down_proj is not None and hasattr(down_proj, "weight"):
                  hidden_dim = down_proj.weight.shape[1]
                  out_std = hidden_dim**-0.5
                  init.trunc_normal_(
                      down_proj.weight,
                      mean=0.0,
                      std=out_std,
                      a=-3 * out_std,
                      b=3 * out_std,
                  )
                  if getattr(down_proj, "bias", None) is not None:
                      init.zeros_(down_proj.bias)
              return
          # Generic Linear layers (projections, lm_head, etc.): std = in_features ** -0.5
          if isinstance(module, nn.Linear):
              fan_in = module.in_features
              std = fan_in**-0.5
              init.trunc_normal_(
                  module.weight,
                  mean=0.0,
                  std=std,
                  a=-3 * std,
                  b=3 * std,
              )
              if module.bias is not None:
                  init.zeros_(module.bias)
              return
          # Rotary embeddings: recompute the non-persistent inverse-frequency buffers from the module's config.
          if isinstance(module, BltRotaryEmbedding):
              rope_fn = (
                  ROPE_INIT_FUNCTIONS[module.rope_type]
                  if module.rope_type != "default"
                  else module.compute_default_rope_parameters
              )
              buffer_value, _ = rope_fn(module.config)
              init.copy_(module.inv_freq, buffer_value)
              init.copy_(module.original_inv_freq, buffer_value)
  class BltLocalEncoder(BltPreTrainedModel):
      """Byte-level local encoder.

      Runs transformer layers over byte embeddings and builds patch embeddings. Patch embeddings are made by
      max-pooling the byte states per patch, projecting them, and refining them with cross-attention back
      onto the byte states.
      """

      config: BltLocalEncoderConfig
      _can_record_outputs = {
          "encoder_attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="local_encoder"),
      }

      def __init__(self, config: BltLocalEncoderConfig):
          super().__init__(config)
          self.gradient_checkpointing = False
          self.config = config
          self.layers = nn.ModuleList(
              [BltTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
          )
          self.rotary_emb = BltRotaryEmbedding(config=config)
          # Each pooled patch vector is expanded into `cross_attn_k` vectors of size hidden_size.
          self.patch_embedding_projection = nn.Linear(
              in_features=config.hidden_size,
              out_features=config.hidden_size * config.cross_attn_k,
              bias=False,
          )
          self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
          # One cross-attention per layer when `cross_attn_all_layers`, otherwise a single one after the last layer.
          self.cross_attn_layers = nn.ModuleList()
          layers_to_add = config.num_hidden_layers if config.cross_attn_all_layers else 1
          for layer_idx in range(layers_to_add):
              self.cross_attn_layers.append(
                  BltCrossAttention(config=config, layer_idx=layer_idx, hidden_size=config.hidden_size)
              )
          self.post_init()

      def forward(
          self,
          input_ids: torch.LongTensor | None = None,
          inputs_embeds: torch.Tensor | None = None,
          patch_embeds: torch.Tensor | None = None,
          attention_mask: torch.Tensor | None = None,
          position_ids: torch.LongTensor | None = None,
          past_key_values: Cache | None = None,
          encoder_attention_mask: torch.Tensor | None = None,
          num_patches: int | None = None,
          patch_ids: torch.Tensor | None = None,
          **kwargs: Unpack[TransformersKwargs],
      ):
          """
          Args:
              input_ids: byte ids, embedded with `embed_tokens` when `inputs_embeds` is not given.
              inputs_embeds: precomputed byte embeddings `(batch, seq_len, hidden_size)`.
              patch_embeds: returned unchanged only when the encoder has no layers; otherwise it is
                  recomputed from the byte states via `patch_reduce`.
              attention_mask: self-attention mask for the byte layers.
              position_ids: defaults to `0..seq_len-1` for every sequence.
              past_key_values: cache passed to every layer.
              encoder_attention_mask: mask for the patch -> byte cross-attention.
              num_patches: number of patch slots to pool into (required when there are layers).
              patch_ids: `(batch, seq_len)` patch index of every byte position.

          Returns:
              `(hidden_states, encoder_cross_states)`: final byte states and
              `(batch, num_patches * cross_attn_k, hidden_size)` patch states.
          """
          if inputs_embeds is None:
              inputs_embeds = self.embed_tokens(input_ids)
          batch_size = inputs_embeds.shape[0]
          # Embedding dropout is applied exactly once; a second back-to-back F.dropout would
          # compound the effective drop rate during training.
          hidden_states = F.dropout(inputs_embeds, p=self.config.dropout, training=self.training)
          if position_ids is None:
              position_ids = (
                  torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0).expand(batch_size, -1)
              )
          position_embeddings = self.rotary_emb(hidden_states, position_ids)

          # Defined up front so a layer-less encoder returns instead of hitting an unbound name.
          encoder_cross_states = patch_embeds
          last_idx = len(self.layers) - 1
          for idx, layer in enumerate(self.layers):
              hidden_states = layer(
                  hidden_states,
                  position_embeddings=position_embeddings,
                  attention_mask=attention_mask,
                  past_key_values=past_key_values,
                  **kwargs,
              )
              if idx == last_idx or self.config.cross_attn_all_layers:
                  patch_embeds = self.patch_reduce(hidden_states, num_patches, patch_ids)
                  patch_embeds = self.patch_embedding_projection(patch_embeds)
                  patch_embeds = patch_embeds.reshape(
                      batch_size, patch_embeds.shape[1] * self.config.cross_attn_k, self.config.hidden_size
                  )
                  cross_idx = idx if self.config.cross_attn_all_layers else 0
                  cross_attention_output, _ = self.cross_attn_layers[cross_idx](
                      hidden_states=patch_embeds,
                      cross_attention_states=hidden_states,
                      attention_mask=encoder_attention_mask,
                      **kwargs,
                  )
                  patch_embeds = patch_embeds + cross_attention_output
                  encoder_cross_states = patch_embeds
          return hidden_states, encoder_cross_states

      def patch_reduce(self, hidden_states, max_num_patches, patch_ids):
          """
          Max-pool byte-level hidden states into one embedding per patch.

          Works with a different number of patches per sequence in the batch.

          Args:
              hidden_states: 3-D tensor `(batch, seq_len, embedding_dim)`.
              max_num_patches: number of patch slots in the output (its second dimension).
              patch_ids: integer tensor `(batch, seq_len)` giving the patch index of every position.
                  Every value must lie in `[0, max_num_patches)`.

          Returns:
              `(batch, max_num_patches, embedding_dim)` tensor. Slot `p` holds the element-wise max over the
              positions assigned to patch `p`; slots that receive no position keep the initial value 0.
          """
          batch_size = hidden_states.shape[0]
          embedding_dim = hidden_states.shape[-1]
          index = patch_ids.unsqueeze(-1).expand(-1, -1, embedding_dim)
          reduced_embeddings = torch.zeros(
              (batch_size, max_num_patches, embedding_dim), dtype=hidden_states.dtype, device=hidden_states.device
          )
          # include_self=False: the zero initial values take no part in the max for slots that receive data.
          return reduced_embeddings.scatter_reduce(
              src=hidden_states,
              dim=1,
              index=index,
              reduce="amax",
              include_self=False,
          )
  610. class BltLocalDecoder(BltPreTrainedModel):
  611. config: BltLocalDecoderConfig
  612. def __init__(self, config: BltLocalDecoderConfig):
  613. super().__init__(config)
  614. self.gradient_checkpointing = False
  615. self.config = config
  616. self.cross_attn_decoder = True
  617. self.layers = nn.ModuleList(
  618. [BltTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  619. )
  620. self.rotary_emb = BltRotaryEmbedding(config=config)
  621. self.patch_embedding_projection = nn.Linear(
  622. in_features=config.hidden_size_global,
  623. out_features=config.hidden_size * config.cross_attn_k,
  624. bias=False,
  625. )
  626. self.norm = BltRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  627. self.cross_attn_layers = nn.ModuleList()
  628. layers_to_add = config.num_hidden_layers if config.cross_attn_all_layers else 1
  629. for layer_idx in range(layers_to_add):
  630. self.cross_attn_layers.append(
  631. BltCrossAttention(config=config, layer_idx=layer_idx, hidden_size=config.hidden_size)
  632. )
  633. self.post_init()
  634. def forward(
  635. self,
  636. input_ids: torch.LongTensor | None = None,
  637. inputs_embeds: torch.Tensor | None = None,
  638. patch_embeds: torch.Tensor | None = None,
  639. attention_mask: torch.Tensor | None = None,
  640. position_ids: torch.LongTensor | None = None,
  641. past_key_values: Cache | None = None,
  642. encoder_attention_mask: torch.Tensor | None = None,
  643. **kwargs: Unpack[TransformersKwargs],
  644. ):
  645. batch_size = inputs_embeds.shape[0]
  646. hidden_states = inputs_embeds
  647. patch_embeds = self.patch_embedding_projection(patch_embeds)
  648. patch_embeds = patch_embeds.reshape(
  649. batch_size, patch_embeds.shape[1] * self.config.cross_attn_k, self.config.hidden_size
  650. )
  651. if patch_embeds is not None and not self.cross_attn_decoder:
  652. hidden_states = hidden_states + patch_embeds
  653. if position_ids is None:
  654. position_ids = (
  655. torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0).expand(batch_size, -1)
  656. )
  657. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  658. hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training)
  659. for i, layer in enumerate(self.layers):
  660. if i == 0 or self.config.cross_attn_all_layers:
  661. cross_attention_output, _ = self.cross_attn_layers[i](
  662. hidden_states=hidden_states,
  663. cross_attention_states=patch_embeds,
  664. attention_mask=encoder_attention_mask,
  665. **kwargs,
  666. )
  667. hidden_states = hidden_states + cross_attention_output
  668. hidden_states = layer(
  669. hidden_states,
  670. position_embeddings=position_embeddings,
  671. attention_mask=attention_mask,
  672. past_key_values=past_key_values,
  673. **kwargs,
  674. )
  675. logits = self.norm(hidden_states)
  676. return logits
  677. class BltGlobalTransformer(BltPreTrainedModel):
  678. config: BltGlobalTransformerConfig
  679. _can_record_outputs = {
  680. "global_attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="global_transformer"),
  681. }
  682. def __init__(self, config: BltGlobalTransformerConfig):
  683. super().__init__(config)
  684. self.config = config
  685. self.layers = nn.ModuleList()
  686. for layer_idx in range(config.num_hidden_layers):
  687. self.layers.append(BltTransformerLayer(config, layer_idx))
  688. self.rotary_emb = BltRotaryEmbedding(config=config)
  689. # Create token embedding projection (use nn.Identity() when no projection needed)
  690. if getattr(config, "encoder_cross_output_size", None) is not None:
  691. self.token_embedding_projection = nn.Linear(
  692. config.encoder_cross_output_size, config.hidden_size, bias=False
  693. )
  694. else:
  695. self.token_embedding_projection = nn.Identity()
  696. self.post_init()
  697. @deprecate_kwarg("input_embeds", version="5.6.0", new_name="inputs_embeds")
  698. def forward(
  699. self,
  700. inputs_embeds: torch.Tensor,
  701. attention_mask: torch.Tensor | None = None,
  702. position_ids: torch.LongTensor | None = None,
  703. past_key_values: Cache | None = None,
  704. **kwargs: Unpack[TransformersKwargs],
  705. ):
  706. batch_size, seq_len, _ = inputs_embeds.shape
  707. hidden_states = self.token_embedding_projection(inputs_embeds)
  708. hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training)
  709. if position_ids is None:
  710. position_ids = (
  711. torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0).expand(batch_size, -1)
  712. )
  713. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  714. for i, layer in enumerate(self.layers):
  715. hidden_states = layer(
  716. hidden_states,
  717. position_embeddings=position_embeddings,
  718. attention_mask=attention_mask,
  719. past_key_values=past_key_values,
  720. **kwargs,
  721. )
  722. return hidden_states
  723. def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: int | None) -> torch.Tensor:
  724. """
  725. Splits patch lengths into smaller segments if they exceed `max_patch_length`.
  726. Pads the result to uniform length across the batch.
  727. Args:
  728. patch_lengths (torch.Tensor): [batch_size, num_patches] tensor of patch lengths.
  729. max_patch_length (int, optional): Maximum allowed length per patch.
  730. Returns:
  731. torch.Tensor: [batch_size, max_len] tensor of split and padded patch lengths.
  732. """
  733. if max_patch_length is None:
  734. return patch_lengths
  735. batch_size = patch_lengths.size(0)
  736. processed = []
  737. for seq in patch_lengths:
  738. splits = []
  739. for length in seq[seq > 0]:
  740. length = length.item()
  741. full_chunks, remainder = divmod(length, max_patch_length)
  742. splits.extend([max_patch_length] * full_chunks)
  743. if remainder:
  744. splits.append(remainder)
  745. processed.append(splits)
  746. # Find max length to pad to
  747. max_len = max(len(splits) for splits in processed)
  748. padded = torch.zeros((batch_size, max_len), dtype=patch_lengths.dtype, device=patch_lengths.device)
  749. for i, splits in enumerate(processed):
  750. if splits:
  751. padded[i, : len(splits)] = torch.tensor(splits, dtype=patch_lengths.dtype, device=patch_lengths.device)
  752. # Trim zero columns
  753. if (padded != 0).any(dim=0).sum() < padded.shape[1]:
  754. last_nonzero = (padded != 0).any(dim=0).nonzero().max().item() + 1
  755. padded = padded[:, :last_nonzero]
  756. return padded
  757. class BltPatcher(BltPreTrainedModel):
  758. config: BltPatcherConfig
  759. def __init__(self, config: BltPatcherConfig):
  760. super().__init__(config)
  761. self.rotary_emb = BltRotaryEmbedding(config=self.config)
  762. self.layers = nn.ModuleList()
  763. for layer_idx in range(self.config.num_hidden_layers):
  764. self.layers.append(BltTransformerLayer(self.config, layer_idx))
  765. self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.hidden_size)
  766. self.norm = BltRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
  767. self.lm_head = nn.Linear(
  768. self.config.hidden_size,
  769. self.config.vocab_size,
  770. bias=False,
  771. )
  772. self.post_init()
  773. def forward(
  774. self,
  775. input_ids: torch.LongTensor | None = None,
  776. attention_mask: torch.Tensor | None = None,
  777. position_ids: torch.LongTensor | None = None,
  778. past_key_values: Cache | None = None,
  779. inputs_embeds: torch.FloatTensor | None = None,
  780. use_cache: bool | None = None,
  781. patch_size: int | None = None,
  782. threshold: float | None = None,
  783. max_patch_length: int | None = None,
  784. **kwargs: Unpack[TransformersKwargs],
  785. ):
  786. if (input_ids is None) ^ (inputs_embeds is not None):
  787. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  788. if inputs_embeds is None:
  789. inputs_embeds = self.embed_tokens(input_ids)
  790. if use_cache and past_key_values is None:
  791. past_key_values = DynamicCache(config=self.config)
  792. if position_ids is None:
  793. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  794. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  795. position_ids = position_ids.unsqueeze(0)
  796. causal_mask = create_causal_mask(
  797. config=self.config,
  798. inputs_embeds=inputs_embeds,
  799. attention_mask=attention_mask,
  800. past_key_values=past_key_values,
  801. position_ids=position_ids,
  802. )
  803. hidden_states = inputs_embeds
  804. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  805. for layer in self.layers:
  806. hidden_states = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=causal_mask)
  807. logits = self.lm_head(self.norm(hidden_states))
  808. prediction_entropies = torch.distributions.Categorical(logits=logits).entropy()
  809. batch_size, sequence_length = inputs_embeds.shape[:2]
  810. if patch_size is not None:
  811. patch_lengths = self.patch_lengths_from_entropies(
  812. entropies=prediction_entropies,
  813. sequence_length=sequence_length,
  814. patch_size=patch_size,
  815. threshold=threshold,
  816. )
  817. else:
  818. patch_lengths = torch.ones(
  819. (batch_size, sequence_length), dtype=inputs_embeds.dtype, device=inputs_embeds.device
  820. )
  821. patch_lengths = process_patch_lengths(patch_lengths, max_patch_length)
  822. return prediction_entropies, patch_lengths, logits
  823. @staticmethod
  824. def patch_lengths_from_entropies(
  825. entropies,
  826. sequence_length,
  827. patch_size=None,
  828. threshold=None,
  829. ):
  830. """
  831. Computes patch lengths from token entropies.
  832. Depending on whether a threshold is provided, the function uses either:
  833. - Thresholding the entropy values (when `threshold` is set).
  834. """
  835. batch_size = entropies.shape[0]
  836. # Always include token 0 and 1 as starting tokens
  837. init_tokens = (
  838. torch.tensor([0, 1], dtype=torch.long, device=entropies.device).unsqueeze(0).repeat(batch_size, 1)
  839. )
  840. offset = init_tokens.shape[1]
  841. # Ignore first token entropy (BOS)
  842. entropies = entropies[:, 1:]
  843. # Threshold the entropy values to define patch start points
  844. patch_mask = entropies > threshold
  845. seq_len = patch_mask.shape[1]
  846. # Create patch IDs (token indices), and add a sentinel to ensure alignment
  847. token_indices = torch.arange(seq_len, device=entropies.device).unsqueeze(0).expand(batch_size, -1)
  848. sentinel = torch.full_like(token_indices, seq_len)
  849. padded_indices = torch.cat([token_indices, sentinel], dim=1)
  850. # Pad mask with inverse to align sentinel correctly
  851. padded_mask = torch.cat([patch_mask, ~patch_mask], dim=1)
  852. # Select indices where mask is True
  853. patch_starts = padded_indices[padded_mask].reshape(batch_size, seq_len)
  854. max_valid_patches = patch_mask.sum(dim=1).max()
  855. patch_starts = patch_starts[:, :max_valid_patches]
  856. # Offset patch starts to account for the two initial tokens
  857. patch_start_ids = torch.cat((init_tokens, patch_starts + offset), dim=1)
  858. # Compute patch end positions by shifting start positions
  859. last_token = torch.full_like(patch_start_ids[:, :1], sequence_length - 1)
  860. patch_ends = torch.cat((patch_start_ids[:, 1:] - 1, last_token), dim=1)
  861. patch_lengths = patch_ends - patch_start_ids + 1
  862. return patch_lengths
  863. def rolling_polynomial_hash(token_tensor, prime: int = 1000000007):
  864. """
  865. A polynomial rolling hash algorithm that converts sequences
  866. of tokens into hash values. The hash is computed as:
  867. hash = (token_0 * prime^0 + token_1 * prime^1 + ... + token_n * prime^n)
  868. The rolling hash allows the model to efficiently
  869. identify and encode recurring byte-level patterns in the input text.
  870. Args:
  871. token_tensor (torch.Tensor): [batch_size, seq_len, group_size] containing token IDs to hash
  872. prime (int): Prime number used as the base for the polynomial hash.
  873. Returns:
  874. torch.Tensor: Hash values of shape [batch_size, seq_len] where each value
  875. represents the hash of the corresponding token group
  876. Example:
  877. >>> tokens = torch.tensor([[1, 2, 3], [4, 5, 6]])
  878. >>> hashes = rolling_polynomial_hash(tokens, prime=31)
  879. >>> # hash[0] = 1*31^0 + 2*31^1 + 3*31^2
  880. >>> # hash[1] = 4*31^0 + 5*31^1 + 6*31^2
  881. """
  882. prime_tensor = torch.tensor(prime, dtype=torch.int64, device=token_tensor.device)
  883. powers = torch.arange(token_tensor.shape[-1], device=token_tensor.device)
  884. prime_powers = prime_tensor**powers
  885. return torch.sum(token_tensor * prime_powers, dim=-1)
  886. def byte_group_hash_function(
  887. token_ids: torch.Tensor, group_size: int = 2, prime: int = 1000000007, max_hash: int = 30000
  888. ):
  889. """Hash token groups and map to range [0, max_hash]."""
  890. with torch.no_grad():
  891. batch_size, seq_len = token_ids.shape
  892. # Add padding for sliding window
  893. padding = torch.zeros(batch_size, group_size - 1, dtype=torch.int64, device=token_ids.device)
  894. padded_tokens = torch.cat([padding, token_ids], dim=1)
  895. # Create sliding windows and compute hashes
  896. windows = padded_tokens.unfold(1, group_size, 1)
  897. hashes = rolling_polynomial_hash(windows, prime)
  898. hash_values = hashes % max_hash
  899. return hash_values
  900. def compute_hash_embeddings(
  901. local_encoder_tokens: torch.Tensor,
  902. local_encoder,
  903. encoder_hash_tok_embedding: nn.Embedding,
  904. encoder_hash_byte_group_nb_functions: int,
  905. encoder_hash_byte_group_size: list,
  906. encoder_hash_byte_group_vocab: int,
  907. ) -> torch.Tensor:
  908. """Compute token embeddings enhanced with hash-based embeddings."""
  909. # Available primes for hash functions
  910. primes = [
  911. 1000000007,
  912. 5915587277,
  913. 1500450271,
  914. 3267000013,
  915. 5754853343,
  916. 4093082899,
  917. 9576890767,
  918. 3628273133,
  919. 2860486313,
  920. 5463458053,
  921. 3367900313,
  922. ]
  923. embeddings = local_encoder.embed_tokens(local_encoder_tokens)
  924. embedding_idx = 0
  925. for func_nb in range(encoder_hash_byte_group_nb_functions):
  926. prime = primes[func_nb % len(primes)] # Cycle through primes if more functions than primes
  927. for group_size in encoder_hash_byte_group_size:
  928. hash_ids = byte_group_hash_function(local_encoder_tokens, group_size, prime, encoder_hash_byte_group_vocab)
  929. # Apply offset to get the correct slice of the fused embedding
  930. offset_hash_ids = hash_ids + embedding_idx * encoder_hash_byte_group_vocab
  931. embeddings += encoder_hash_tok_embedding(offset_hash_ids).to(embeddings.device)
  932. embedding_idx += 1
  933. return embeddings
  934. def _prepare_patch_cross_attention_mask(
  935. patch_ids: torch.Tensor,
  936. num_patches: int,
  937. sequence_length: int,
  938. patches_as_queries: bool = False,
  939. cross_attn_k: int = 1,
  940. dtype: torch.dtype = torch.float32,
  941. ) -> tuple[torch.Tensor, torch.Tensor]:
  942. """
  943. Prepare cross-attention mask for patch-based attention, following mllama's robust approach.
  944. This function creates masks that control which patches can attend to which other patches,
  945. with support for query/key role swapping and cross-attention multipliers.
  946. Args:
  947. patch_ids (torch.Tensor): Tensor of shape [batch_size, seq_len] containing patch ids.
  948. num_patches (int): Total number of patches.
  949. sequence_length (int): Length of the sequence.
  950. patches_as_queries (bool): If True, patches are used as queries, otherwise as keys.
  951. cross_attn_k (int): Cross-attention multiplier for repeating patches.
  952. dtype (torch.dtype): Data type for the output mask.
  953. Returns:
  954. Tuple[torch.Tensor, torch.Tensor]:
  955. - cross_attention_mask: 4D tensor [batch_size, 1, q_len, kv_len]
  956. """
  957. batch_size, seq_len = patch_ids.shape
  958. device = patch_ids.device
  959. # Determine query and key lengths based on configuration
  960. if patches_as_queries:
  961. q_len = num_patches * cross_attn_k
  962. kv_len = sequence_length
  963. # Create patch-to-sequence mapping
  964. q_patch_ids = (
  965. torch.arange(num_patches, device=device)
  966. .unsqueeze(0)
  967. .unsqueeze(-1)
  968. .expand(batch_size, num_patches, seq_len)
  969. )
  970. kv_patch_ids = patch_ids.unsqueeze(1).expand(batch_size, num_patches, seq_len)
  971. else:
  972. q_len = sequence_length
  973. kv_len = num_patches * cross_attn_k
  974. # Create sequence-to-patch mapping
  975. q_patch_ids = patch_ids.unsqueeze(-1).expand(batch_size, seq_len, num_patches)
  976. kv_patch_ids = (
  977. torch.arange(num_patches, device=device).unsqueeze(0).unsqueeze(0).expand(batch_size, seq_len, num_patches)
  978. )
  979. # Create base attention mask - boolean mask where True means "should attend"
  980. # Exact patch matching
  981. cross_attention_mask = q_patch_ids == kv_patch_ids
  982. # Handle cross_attn_k multiplier by repeating along appropriate dimension
  983. repeat_dim = 1 if patches_as_queries else -1
  984. cross_attention_mask = cross_attention_mask.repeat_interleave(cross_attn_k, dim=repeat_dim)
  985. # Validate dimensions
  986. expected_shape = (batch_size, q_len, kv_len)
  987. if cross_attention_mask.shape != expected_shape:
  988. raise ValueError(
  989. f"Cross attention mask shape {cross_attention_mask.shape} doesn't match expected {expected_shape}"
  990. )
  991. # Reshape so it can be used by attn module - add head dimension
  992. cross_attention_mask = cross_attention_mask.unsqueeze(1) # [batch_size, 1, q_len, kv_len]
  993. # Invert the mask (following mllama pattern exactly)
  994. # True -> 0.0 (attend), False -> 1.0 (will become -inf)
  995. inverted_cross_attn_mask = 1.0 - cross_attention_mask.to(dtype)
  996. cross_attention_mask = inverted_cross_attn_mask.masked_fill(
  997. inverted_cross_attn_mask.to(torch.bool), torch.finfo(dtype).min
  998. )
  999. return cross_attention_mask
  1000. class BltModel(BltPreTrainedModel):
  1001. def __init__(self, config: BltConfig):
  1002. super().__init__(config)
  1003. self.gradient_checkpointing = False
  1004. self.config = config
  1005. self.local_encoder = BltLocalEncoder(config.encoder_config)
  1006. self.global_transformer = BltGlobalTransformer(config.global_config)
  1007. self.local_decoder = BltLocalDecoder(config.decoder_config)
  1008. num_embeddings = config.encoder_hash_byte_group_nb_functions * len(config.encoder_hash_byte_group_size)
  1009. total_vocab_size = config.encoder_hash_byte_group_vocab * num_embeddings
  1010. self.encoder_hash_tok_embedding = nn.Embedding(total_vocab_size, config.encoder_config.hidden_size)
  1011. if self.config.patch_in_forward:
  1012. self.patcher = BltPatcher(config.patcher_config)
  1013. self.patcher.eval()
  1014. for param in self.patcher.parameters():
  1015. param.requires_grad = False
  1016. else:
  1017. self.patcher = None
  1018. self.post_init()
  1019. @merge_with_config_defaults
  1020. @capture_outputs
  1021. def forward(
  1022. self,
  1023. input_ids: torch.LongTensor | None = None,
  1024. patch_lengths: torch.Tensor | None = None,
  1025. attention_mask: torch.Tensor | None = None,
  1026. position_ids: torch.LongTensor | None = None,
  1027. past_key_values: Cache | None = None,
  1028. inputs_embeds: torch.FloatTensor | None = None,
  1029. use_cache: bool | None = None,
  1030. **kwargs: Unpack[TransformersKwargs],
  1031. ) -> tuple | BaseModelOutputWithPast:
  1032. if (input_ids is None) ^ (inputs_embeds is not None):
  1033. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  1034. if use_cache:
  1035. if past_key_values is None:
  1036. past_key_values = EncoderDecoderCache(
  1037. DynamicCache(config=self.config), DynamicCache(config=self.config)
  1038. )
  1039. elif not isinstance(past_key_values, EncoderDecoderCache):
  1040. # BLT uses an encoder-decoder cache even though it is not en encoder-decoder model. Create a cross-cache
  1041. # if not yet created by the user
  1042. past_key_values = EncoderDecoderCache(past_key_values, DynamicCache(config=self.config))
  1043. # Extract input embeddings as early as possible
  1044. if inputs_embeds is not None:
  1045. encoder_embeds = inputs_embeds
  1046. batch_size, sequence_length, _ = inputs_embeds.shape
  1047. else:
  1048. batch_size, sequence_length = input_ids.shape
  1049. encoder_embeds = compute_hash_embeddings(
  1050. input_ids,
  1051. self.local_encoder,
  1052. self.encoder_hash_tok_embedding,
  1053. self.config.encoder_hash_byte_group_nb_functions,
  1054. self.config.encoder_hash_byte_group_size,
  1055. self.config.encoder_hash_byte_group_vocab,
  1056. )
  1057. if patch_lengths is None:
  1058. if self.config.patching_mode == "entropy" and self.patcher is not None:
  1059. if input_ids is None:
  1060. raise ValueError("input_ids is required for entropy-based patching")
  1061. _, patch_lengths, _ = self.patcher(
  1062. input_ids,
  1063. patch_size=self.config.patch_size,
  1064. threshold=self.config.patching_threshold,
  1065. max_patch_length=self.config.max_patch_length,
  1066. patching_batch_size=self.config.patching_batch_size,
  1067. device=input_ids.device,
  1068. )
  1069. else:
  1070. device = input_ids.device if input_ids is not None else inputs_embeds.device
  1071. dtype = input_ids.dtype if input_ids is not None else inputs_embeds.dtype
  1072. patch_lengths = process_patch_lengths(
  1073. torch.ones((batch_size, sequence_length + 1), dtype=dtype, device=device),
  1074. self.config.max_patch_length,
  1075. )
  1076. patch_ids = self._patch_ids_from_lengths(patch_lengths, sequence_length)
  1077. if position_ids is None:
  1078. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  1079. position_ids = torch.arange(encoder_embeds.shape[1], device=encoder_embeds.device) + past_seen_tokens
  1080. position_ids = position_ids.unsqueeze(0)
  1081. causal_mask = create_causal_mask(
  1082. config=self.config,
  1083. inputs_embeds=encoder_embeds,
  1084. attention_mask=attention_mask,
  1085. past_key_values=past_key_values.self_attention_cache if past_key_values is not None else None,
  1086. position_ids=position_ids,
  1087. )
  1088. cross_attn_mask_enc = _prepare_patch_cross_attention_mask(
  1089. patch_ids=patch_ids,
  1090. num_patches=patch_lengths.shape[1],
  1091. sequence_length=sequence_length,
  1092. patches_as_queries=True,
  1093. cross_attn_k=self.config.cross_attn_k,
  1094. dtype=encoder_embeds.dtype,
  1095. )
  1096. encoder_hidden_states, encoder_cross_states = self.local_encoder(
  1097. input_ids=input_ids,
  1098. inputs_embeds=encoder_embeds,
  1099. attention_mask=causal_mask,
  1100. position_ids=position_ids,
  1101. encoder_attention_mask=cross_attn_mask_enc,
  1102. num_patches=patch_lengths.shape[1],
  1103. patch_ids=patch_ids,
  1104. past_key_values=past_key_values.self_attention_cache if past_key_values is not None else None,
  1105. **kwargs,
  1106. )
  1107. encoder_cross_states = encoder_cross_states.view(batch_size, patch_lengths.shape[1], -1)
  1108. global_position_ids = torch.arange(0, encoder_cross_states.shape[1], device=encoder_cross_states.device)
  1109. global_position_ids = global_position_ids.unsqueeze(0)
  1110. global_causal_mask = create_causal_mask(
  1111. config=self.config,
  1112. inputs_embeds=encoder_cross_states,
  1113. attention_mask=None,
  1114. past_key_values=None,
  1115. position_ids=None,
  1116. )
  1117. global_hidden_states = self.global_transformer(
  1118. inputs_embeds=encoder_cross_states,
  1119. attention_mask=global_causal_mask,
  1120. position_ids=global_position_ids,
  1121. **kwargs,
  1122. )
  1123. decoder_patch_ids = self._patch_ids_from_lengths(patch_lengths[:, 1:], sequence_length)
  1124. cross_attn_mask_dec = _prepare_patch_cross_attention_mask(
  1125. patch_ids=decoder_patch_ids,
  1126. num_patches=patch_lengths.shape[1],
  1127. sequence_length=sequence_length,
  1128. patches_as_queries=False,
  1129. cross_attn_k=self.config.cross_attn_k,
  1130. dtype=encoder_embeds.dtype,
  1131. )
  1132. output = self.local_decoder(
  1133. input_ids=input_ids,
  1134. inputs_embeds=encoder_hidden_states,
  1135. patch_embeds=global_hidden_states,
  1136. attention_mask=causal_mask,
  1137. position_ids=position_ids,
  1138. past_key_values=past_key_values.cross_attention_cache if past_key_values is not None else None,
  1139. encoder_attention_mask=cross_attn_mask_dec,
  1140. **kwargs,
  1141. )
  1142. return BaseModelOutputWithPast(
  1143. last_hidden_state=output,
  1144. past_key_values=past_key_values,
  1145. )
  1146. def get_input_embeddings(self):
  1147. return self.local_encoder.embed_tokens
  1148. def set_input_embeddings(self, value):
  1149. self.local_encoder.embed_tokens = value
  1150. def _patch_ids_from_lengths(self, patch_lengths: torch.Tensor, seq_len: int) -> torch.Tensor:
  1151. batch_size = patch_lengths.shape[0]
  1152. patch_starts = torch.cat(
  1153. [
  1154. torch.zeros(batch_size, 1, dtype=patch_lengths.dtype, device=patch_lengths.device),
  1155. patch_lengths.cumsum(dim=-1)[:, :-1],
  1156. ],
  1157. dim=-1,
  1158. )
  1159. token_positions = torch.arange(seq_len, device=patch_lengths.device)
  1160. return (patch_starts.unsqueeze(1) <= token_positions.unsqueeze(0).unsqueeze(-1)).sum(dim=-1) - 1
  1161. @auto_docstring(
  1162. custom_intro="""
  1163. The Blt Text Model with a language modeling head on top.
  1164. """
  1165. )
  1166. class BltForCausalLM(BltPreTrainedModel, GenerationMixin):
  1167. config: BltConfig
  1168. _can_compile_fullgraph = False
  1169. base_model_prefix = "model"
  1170. _tied_weights_keys = {"model.local_encoder.embed_tokens.weight": "lm_head.weight"}
  1171. def __init__(self, config: BltConfig):
  1172. super().__init__(config)
  1173. self.text_config = config.get_text_config()
  1174. self.vocab_size = config.vocab_size
  1175. self.model = BltModel(config)
  1176. self.lm_head = nn.Linear(config.decoder_config.hidden_size, config.vocab_size, bias=False)
  1177. self.post_init()
  1178. @can_return_tuple
  1179. @auto_docstring
  1180. def forward(
  1181. self,
  1182. input_ids: torch.LongTensor | None = None,
  1183. attention_mask: torch.Tensor | None = None,
  1184. position_ids: torch.LongTensor | None = None,
  1185. cross_attention_states: torch.LongTensor | None = None, # Keep for compatibility
  1186. cross_attention_mask: torch.LongTensor | None = None,
  1187. full_text_row_masked_out_mask: tuple[torch.Tensor, torch.Tensor] | None = None,
  1188. past_key_values: Cache | None = None,
  1189. inputs_embeds: torch.FloatTensor | None = None,
  1190. labels: torch.LongTensor | None = None,
  1191. use_cache: bool | None = None,
  1192. logits_to_keep: int | torch.Tensor = 0,
  1193. **kwargs: Unpack[TransformersKwargs],
  1194. ) -> tuple | CausalLMOutputWithPast:
  1195. r"""
  1196. cross_attention_states (`torch.FloatTensor`, *optional*):
  1197. Output of the vision model, used for cross-attention. This tensor contains the processed image features that
  1198. the language model will attend to.
  1199. cross_attention_mask (`torch.Tensor` of shape `(batch_size, seq_length, max_num_images, max_num_tiles)`, *optional*):
  1200. Cross-attention mask to control the interaction between text tokens and image tiles.
  1201. This 4D tensor defines which image tiles each text token should attend to.
  1202. For each text token (in seq_length):
  1203. - 1 indicates the token **should attend** to the corresponding image tile
  1204. - 0 indicates the token **should not attend** to the corresponding image tile
  1205. full_text_row_masked_out_mask (`tuple[torch.Tensor, torch.Tensor]`, *optional*):
  1206. A tuple containing two tensors that mask out rows in the cross-attention mechanism:
  1207. - The first tensor has shape `(batch_size, 1, seq_length, 1)` and contains values of 0 or 1.
  1208. A value of 0 indicates that the corresponding text token's entire row in the cross-attention
  1209. matrix should be masked out (all image tokens ignored).
  1210. - The second tensor has the same shape and is used internally to apply the masking during
  1211. the forward pass of cross-attention layers.
  1212. This mask is derived from the cross_attention_mask and is used to handle cases where a text token
  1213. should not attend to any image token.
  1214. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1215. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  1216. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  1217. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  1218. Example:
  1219. ```python
  1220. >>> from transformers import AutoTokenizer, BltForCausalLM
  1221. >>> model = BltForCausalLM.from_pretrained("itazap/blt-1b-hf")
  1222. >>> tokenizer = AutoTokenizer.from_pretrained("itazap/blt-1b-hf")
  1223. >>> prompt = "If I had to write a haiku, it would be:"
  1224. >>> inputs = tokenizer(prompt, return_tensors="pt")
  1225. >>> # Generate
  1226. >>> generate_ids = model.generate(inputs.input_ids, max_length=40, do_sample=True, temperature=0.6)
  1227. >>> result = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  1228. >>> print(result)
  1229. If I had to write a haiku, it would be: "Snowflakes gently fall" - simple, yet peaceful.
  1230. I love the idea of snowflakes gently falling, each one
  1231. ```
  1232. """
  1233. # Call parent forward but exclude cross_attention_states from model call
  1234. outputs = self.model(
  1235. input_ids=input_ids,
  1236. attention_mask=attention_mask,
  1237. position_ids=position_ids,
  1238. cross_attention_mask=cross_attention_mask,
  1239. full_text_row_masked_out_mask=full_text_row_masked_out_mask,
  1240. past_key_values=past_key_values,
  1241. inputs_embeds=inputs_embeds,
  1242. use_cache=use_cache,
  1243. **kwargs,
  1244. )
  1245. hidden_states = outputs.last_hidden_state
  1246. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  1247. logits = self.lm_head(hidden_states[:, slice_indices, :]).float()
  1248. loss = None
  1249. if labels is not None:
  1250. loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
  1251. return CausalLMOutputWithPast(
  1252. loss=loss,
  1253. logits=logits,
  1254. past_key_values=outputs.past_key_values,
  1255. hidden_states=outputs.hidden_states,
  1256. attentions=outputs.attentions,
  1257. )
  1258. __all__ = ["BltPreTrainedModel", "BltModel", "BltPatcher", "BltForCausalLM"]