modular_blt.py 49 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195
  1. # Copyright 2025 HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """Blt modular model, inheriting from Mllama where appropriate."""
  15. from collections.abc import Callable
  16. import torch
  17. import torch.distributions
  18. import torch.nn as nn
  19. import torch.nn.functional as F
  20. from ... import initialization as init
  21. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  22. from ...generation import GenerationMixin
  23. from ...masking_utils import create_causal_mask
  24. from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  25. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  26. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
  27. from ...processing_utils import Unpack
  28. from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
  29. from ...utils.deprecation import deprecate_kwarg
  30. from ...utils.generic import maybe_autocast, merge_with_config_defaults
  31. from ...utils.output_capturing import OutputRecorder, capture_outputs
  32. from ..cohere2.modeling_cohere2 import rotate_half # noqa: F401
  33. from ..llama.modeling_llama import LlamaRotaryEmbedding
  34. from ..mllama.modeling_mllama import (
  35. MllamaPreTrainedModel,
  36. MllamaSelfAttentionDecoderLayer,
  37. MllamaTextCrossAttention,
  38. MllamaTextMLP,
  39. MllamaTextRMSNorm,
  40. MllamaTextSelfAttention,
  41. eager_attention_forward,
  42. )
  43. from .configuration_blt import (
  44. BltConfig,
  45. BltGlobalTransformerConfig,
  46. BltLocalDecoderConfig,
  47. BltLocalEncoderConfig,
  48. BltPatcherConfig,
  49. )
  50. logger = logging.get_logger(__name__)
  51. def rolling_polynomial_hash(token_tensor, prime: int = 1000000007):
  52. """
  53. A polynomial rolling hash algorithm that converts sequences
  54. of tokens into hash values. The hash is computed as:
  55. hash = (token_0 * prime^0 + token_1 * prime^1 + ... + token_n * prime^n)
  56. The rolling hash allows the model to efficiently
  57. identify and encode recurring byte-level patterns in the input text.
  58. Args:
  59. token_tensor (torch.Tensor): [batch_size, seq_len, group_size] containing token IDs to hash
  60. prime (int): Prime number used as the base for the polynomial hash.
  61. Returns:
  62. torch.Tensor: Hash values of shape [batch_size, seq_len] where each value
  63. represents the hash of the corresponding token group
  64. Example:
  65. >>> tokens = torch.tensor([[1, 2, 3], [4, 5, 6]])
  66. >>> hashes = rolling_polynomial_hash(tokens, prime=31)
  67. >>> # hash[0] = 1*31^0 + 2*31^1 + 3*31^2
  68. >>> # hash[1] = 4*31^0 + 5*31^1 + 6*31^2
  69. """
  70. prime_tensor = torch.tensor(prime, dtype=torch.int64, device=token_tensor.device)
  71. powers = torch.arange(token_tensor.shape[-1], device=token_tensor.device)
  72. prime_powers = prime_tensor**powers
  73. return torch.sum(token_tensor * prime_powers, dim=-1)
  74. def byte_group_hash_function(
  75. token_ids: torch.Tensor, group_size: int = 2, prime: int = 1000000007, max_hash: int = 30000
  76. ):
  77. """Hash token groups and map to range [0, max_hash]."""
  78. with torch.no_grad():
  79. batch_size, seq_len = token_ids.shape
  80. # Add padding for sliding window
  81. padding = torch.zeros(batch_size, group_size - 1, dtype=torch.int64, device=token_ids.device)
  82. padded_tokens = torch.cat([padding, token_ids], dim=1)
  83. # Create sliding windows and compute hashes
  84. windows = padded_tokens.unfold(1, group_size, 1)
  85. hashes = rolling_polynomial_hash(windows, prime)
  86. hash_values = hashes % max_hash
  87. return hash_values
  88. def compute_hash_embeddings(
  89. local_encoder_tokens: torch.Tensor,
  90. local_encoder,
  91. encoder_hash_tok_embedding: nn.Embedding,
  92. encoder_hash_byte_group_nb_functions: int,
  93. encoder_hash_byte_group_size: list,
  94. encoder_hash_byte_group_vocab: int,
  95. ) -> torch.Tensor:
  96. """Compute token embeddings enhanced with hash-based embeddings."""
  97. # Available primes for hash functions
  98. primes = [
  99. 1000000007,
  100. 5915587277,
  101. 1500450271,
  102. 3267000013,
  103. 5754853343,
  104. 4093082899,
  105. 9576890767,
  106. 3628273133,
  107. 2860486313,
  108. 5463458053,
  109. 3367900313,
  110. ]
  111. embeddings = local_encoder.embed_tokens(local_encoder_tokens)
  112. embedding_idx = 0
  113. for func_nb in range(encoder_hash_byte_group_nb_functions):
  114. prime = primes[func_nb % len(primes)] # Cycle through primes if more functions than primes
  115. for group_size in encoder_hash_byte_group_size:
  116. hash_ids = byte_group_hash_function(local_encoder_tokens, group_size, prime, encoder_hash_byte_group_vocab)
  117. # Apply offset to get the correct slice of the fused embedding
  118. offset_hash_ids = hash_ids + embedding_idx * encoder_hash_byte_group_vocab
  119. embeddings += encoder_hash_tok_embedding(offset_hash_ids).to(embeddings.device)
  120. embedding_idx += 1
  121. return embeddings
  122. def _prepare_patch_cross_attention_mask(
  123. patch_ids: torch.Tensor,
  124. num_patches: int,
  125. sequence_length: int,
  126. patches_as_queries: bool = False,
  127. cross_attn_k: int = 1,
  128. dtype: torch.dtype = torch.float32,
  129. ) -> tuple[torch.Tensor, torch.Tensor]:
  130. """
  131. Prepare cross-attention mask for patch-based attention, following mllama's robust approach.
  132. This function creates masks that control which patches can attend to which other patches,
  133. with support for query/key role swapping and cross-attention multipliers.
  134. Args:
  135. patch_ids (torch.Tensor): Tensor of shape [batch_size, seq_len] containing patch ids.
  136. num_patches (int): Total number of patches.
  137. sequence_length (int): Length of the sequence.
  138. patches_as_queries (bool): If True, patches are used as queries, otherwise as keys.
  139. cross_attn_k (int): Cross-attention multiplier for repeating patches.
  140. dtype (torch.dtype): Data type for the output mask.
  141. Returns:
  142. Tuple[torch.Tensor, torch.Tensor]:
  143. - cross_attention_mask: 4D tensor [batch_size, 1, q_len, kv_len]
  144. """
  145. batch_size, seq_len = patch_ids.shape
  146. device = patch_ids.device
  147. # Determine query and key lengths based on configuration
  148. if patches_as_queries:
  149. q_len = num_patches * cross_attn_k
  150. kv_len = sequence_length
  151. # Create patch-to-sequence mapping
  152. q_patch_ids = (
  153. torch.arange(num_patches, device=device)
  154. .unsqueeze(0)
  155. .unsqueeze(-1)
  156. .expand(batch_size, num_patches, seq_len)
  157. )
  158. kv_patch_ids = patch_ids.unsqueeze(1).expand(batch_size, num_patches, seq_len)
  159. else:
  160. q_len = sequence_length
  161. kv_len = num_patches * cross_attn_k
  162. # Create sequence-to-patch mapping
  163. q_patch_ids = patch_ids.unsqueeze(-1).expand(batch_size, seq_len, num_patches)
  164. kv_patch_ids = (
  165. torch.arange(num_patches, device=device).unsqueeze(0).unsqueeze(0).expand(batch_size, seq_len, num_patches)
  166. )
  167. # Create base attention mask - boolean mask where True means "should attend"
  168. # Exact patch matching
  169. cross_attention_mask = q_patch_ids == kv_patch_ids
  170. # Handle cross_attn_k multiplier by repeating along appropriate dimension
  171. repeat_dim = 1 if patches_as_queries else -1
  172. cross_attention_mask = cross_attention_mask.repeat_interleave(cross_attn_k, dim=repeat_dim)
  173. # Validate dimensions
  174. expected_shape = (batch_size, q_len, kv_len)
  175. if cross_attention_mask.shape != expected_shape:
  176. raise ValueError(
  177. f"Cross attention mask shape {cross_attention_mask.shape} doesn't match expected {expected_shape}"
  178. )
  179. # Reshape so it can be used by attn module - add head dimension
  180. cross_attention_mask = cross_attention_mask.unsqueeze(1) # [batch_size, 1, q_len, kv_len]
  181. # Invert the mask (following mllama pattern exactly)
  182. # True -> 0.0 (attend), False -> 1.0 (will become -inf)
  183. inverted_cross_attn_mask = 1.0 - cross_attention_mask.to(dtype)
  184. cross_attention_mask = inverted_cross_attn_mask.masked_fill(
  185. inverted_cross_attn_mask.to(torch.bool), torch.finfo(dtype).min
  186. )
  187. return cross_attention_mask
  188. def process_patch_lengths(patch_lengths: torch.Tensor, max_patch_length: int | None) -> torch.Tensor:
  189. """
  190. Splits patch lengths into smaller segments if they exceed `max_patch_length`.
  191. Pads the result to uniform length across the batch.
  192. Args:
  193. patch_lengths (torch.Tensor): [batch_size, num_patches] tensor of patch lengths.
  194. max_patch_length (int, optional): Maximum allowed length per patch.
  195. Returns:
  196. torch.Tensor: [batch_size, max_len] tensor of split and padded patch lengths.
  197. """
  198. if max_patch_length is None:
  199. return patch_lengths
  200. batch_size = patch_lengths.size(0)
  201. processed = []
  202. for seq in patch_lengths:
  203. splits = []
  204. for length in seq[seq > 0]:
  205. length = length.item()
  206. full_chunks, remainder = divmod(length, max_patch_length)
  207. splits.extend([max_patch_length] * full_chunks)
  208. if remainder:
  209. splits.append(remainder)
  210. processed.append(splits)
  211. # Find max length to pad to
  212. max_len = max(len(splits) for splits in processed)
  213. padded = torch.zeros((batch_size, max_len), dtype=patch_lengths.dtype, device=patch_lengths.device)
  214. for i, splits in enumerate(processed):
  215. if splits:
  216. padded[i, : len(splits)] = torch.tensor(splits, dtype=patch_lengths.dtype, device=patch_lengths.device)
  217. # Trim zero columns
  218. if (padded != 0).any(dim=0).sum() < padded.shape[1]:
  219. last_nonzero = (padded != 0).any(dim=0).nonzero().max().item() + 1
  220. padded = padded[:, :last_nonzero]
  221. return padded
  222. class BltMLP(MllamaTextMLP):
  223. pass
  224. class BltRMSNorm(MllamaTextRMSNorm):
  225. pass
  226. class BltRotaryEmbedding(LlamaRotaryEmbedding):
  227. @torch.no_grad()
  228. @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
  229. def forward(self, x, position_ids):
  230. inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
  231. position_ids_expanded = position_ids[:, None, :].float()
  232. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  233. with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  234. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  235. emb = torch.repeat_interleave(freqs, 2, dim=-1) # diff from Llama: we interleave() instead of cat()
  236. cos = emb.cos() * self.attention_scaling
  237. sin = emb.sin() * self.attention_scaling
  238. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
class BltTransformerLayer(MllamaSelfAttentionDecoderLayer):
    """
    Self-attention + MLP transformer layer.

    The forward pass is inherited from `MllamaSelfAttentionDecoderLayer`; only the submodules are
    replaced with their Blt counterparts.
    """

    def __init__(self, config, layer_idx: int):
        # NOTE(review): argument-less super().__init__() — presumably relies on the modular-file conversion
        # expanding the parent's constructor with this signature; confirm before changing it.
        super().__init__()
        self.self_attn = BltSelfAttention(config=config, layer_idx=layer_idx)
        self.mlp = BltMLP(config)
        # Two RMSNorms of width `config.hidden_size`, named after their position relative to attention.
        self.input_layernorm = BltRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = BltRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
class BltSelfAttention(MllamaTextSelfAttention):
    """Self-attention module; all behavior comes from `MllamaTextSelfAttention`."""

    def __init__(self, config: BltConfig, layer_idx: int):
        # Plain pass-through to the parent constructor.
        super().__init__(config, layer_idx)
class BltCrossAttention(MllamaTextCrossAttention):
    """
    Cross-attention module for Blt.

    Queries come from `hidden_states` and keys/values from `cross_attention_states`. Each side is normalized
    with its own RMSNorm before projection, and the query input is added back as a residual before returning.
    """

    def __init__(self, config: BltConfig, layer_idx: int, hidden_size: int | None = None):
        # NOTE(review): argument-less super().__init__() and an unused `hidden_size` argument — presumably the
        # parent's constructor body (providing self.hidden_size, self.num_heads, the projections, ...) is
        # expanded by the modular-file conversion; confirm before editing.
        super().__init__()
        self.is_causal = False
        self.q_norm = BltRMSNorm(self.hidden_size, eps=config.rms_norm_eps)  # applied to queries before q_proj
        self.k_norm = BltRMSNorm(self.hidden_size, eps=config.rms_norm_eps)  # applied to the key/value source

    def forward(
        self,
        hidden_states: torch.Tensor,
        cross_attention_states: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ):
        """
        Attend from `hidden_states` (queries) to `cross_attention_states` (keys/values).

        Args:
            hidden_states: ``[batch_size, q_len, hidden]`` query-side states.
            cross_attention_states: Key/value-side states with the same batch size. Despite the `None` default
                it is required, since it is normalized and projected unconditionally.
            attention_mask: Passed straight to the attention interface (presumably an additive mask such as
                the one built by `_prepare_patch_cross_attention_mask`).

        Returns:
            tuple: ``(attn_output, attn_weights)``. ``attn_output`` already includes the residual `hidden_states`.
        """
        bsz, q_len, _ = hidden_states.size()
        query_states = self.q_norm(hidden_states)
        query_states = self.q_proj(query_states)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        cross_attention_states = self.k_norm(cross_attention_states)
        key_states = self.k_proj(cross_attention_states)
        value_states = self.v_proj(cross_attention_states)
        # The key/value length is inferred (-1) and may differ from q_len.
        key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        # Resolve the configured attention backend, with eager_attention_forward passed as the default.
        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
            self.config._attn_implementation, eager_attention_forward
        )
        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.dropout,
            scaling=self.scaling,
            **kwargs,
        )
        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        # Residual connection is applied inside the module (the un-normalized query input is added back).
        attn_output = attn_output + hidden_states
        return attn_output, attn_weights
  289. @auto_docstring
  290. class BltPreTrainedModel(MllamaPreTrainedModel):
  291. config: BltConfig
  292. _supports_attention_backend = False
  293. _supports_flash_attn = False
  294. _supports_flex_attn = False
  295. _no_split_modules = ["BltTransformerLayer"]
  296. _can_record_outputs = {
  297. "hidden_states": OutputRecorder(BltTransformerLayer, index=0),
  298. "attentions": OutputRecorder(BltSelfAttention, index=1),
  299. }
  300. # Weight initialization is adapted from:
  301. # - https://github.com/facebookresearch/blt/blob/main/bytelatent/model/blt.py
  302. # - https://github.com/pytorch/torchtitan/blob/main/torchtitan/experiments/transformers_modeling_backend/model/model.py
  303. #
  304. # Both implementations use truncated normal initialization with std ~ 1 / sqrt(d_model)
  305. # (or 1 / sqrt(hidden_dim) for FFN outputs), and unit initialization for normalization layers.
  306. # We follow the same scheme here, but expressed in the Transformers APIs.
  307. @torch.no_grad()
  308. def _init_weights(self, module):
  309. """
  310. Initialize BLT weights following the original ByteLatentTransformer:
  311. - Most weights are drawn from a truncated normal.
  312. - Scale is ~ 1 / sqrt(model_dim) (or 1 / sqrt(hidden_dim) for FFN outputs).
  313. - Norm layers are set to weight = 1, bias = 0.
  314. """
  315. class_name = module.__class__.__name__
  316. # Norms: RMSNorm / LayerNorm
  317. if isinstance(module, (BltRMSNorm, nn.LayerNorm)) or "RMSNorm" in class_name or "LayerNorm" in class_name:
  318. if getattr(module, "weight", None) is not None:
  319. init.ones_(module.weight)
  320. if getattr(module, "bias", None) is not None:
  321. init.zeros_(module.bias)
  322. return
  323. # Embeddings (encoder / patcher / hash embeddings)
  324. if isinstance(module, nn.Embedding):
  325. hidden_size = getattr(self.config, "hidden_size", None)
  326. if hidden_size is None and hasattr(self.config, "encoder_config"):
  327. hidden_size = getattr(self.config.encoder_config, "hidden_size", None)
  328. if hidden_size is None:
  329. hidden_size = module.embedding_dim
  330. std = hidden_size**-0.5
  331. init.trunc_normal_(
  332. module.weight,
  333. mean=0.0,
  334. std=std,
  335. a=-3 * std,
  336. b=3 * std,
  337. )
  338. if module.padding_idx is not None:
  339. init.zeros_(module.weight[module.padding_idx])
  340. return
  341. # Self-attention / cross-attention projections
  342. if isinstance(module, (BltSelfAttention, BltCrossAttention)) or class_name in (
  343. "MllamaTextSelfAttention",
  344. "MllamaTextCrossAttention",
  345. ):
  346. dim = getattr(self.config, "hidden_size", None)
  347. if dim is None and hasattr(module, "hidden_size"):
  348. dim = module.hidden_size
  349. if dim is None:
  350. for name in ("q_proj", "k_proj", "v_proj", "o_proj", "dense"):
  351. proj = getattr(module, name, None)
  352. if proj is not None and hasattr(proj, "weight"):
  353. dim = proj.weight.shape[-1]
  354. break
  355. if dim is None:
  356. return
  357. std = dim**-0.5
  358. # Input projections (q, k, v)
  359. for proj_name in ("q_proj", "k_proj", "v_proj"):
  360. proj = getattr(module, proj_name, None)
  361. if proj is not None and hasattr(proj, "weight"):
  362. init.trunc_normal_(
  363. proj.weight,
  364. mean=0.0,
  365. std=std,
  366. a=-3 * std,
  367. b=3 * std,
  368. )
  369. if getattr(proj, "bias", None) is not None:
  370. init.zeros_(proj.bias)
  371. # Output projection: o_proj or dense
  372. o_proj = getattr(module, "o_proj", getattr(module, "dense", None))
  373. if o_proj is not None and hasattr(o_proj, "weight"):
  374. init.trunc_normal_(
  375. o_proj.weight,
  376. mean=0.0,
  377. std=std,
  378. a=-3 * std,
  379. b=3 * std,
  380. )
  381. if getattr(o_proj, "bias", None) is not None:
  382. init.zeros_(o_proj.bias)
  383. return
  384. # MLP / FFN blocks
  385. if isinstance(module, BltMLP) or class_name == "MllamaTextMLP":
  386. hidden_size = getattr(self.config, "hidden_size", None)
  387. if hidden_size is None and hasattr(self.config, "decoder_config"):
  388. hidden_size = getattr(self.config.decoder_config, "hidden_size", None)
  389. if hidden_size is None and hasattr(self.config, "encoder_config"):
  390. hidden_size = getattr(self.config.encoder_config, "hidden_size", None)
  391. # Input-side std
  392. in_std = None
  393. if hidden_size is not None:
  394. in_std = hidden_size**-0.5
  395. gate_proj = getattr(module, "gate_proj", getattr(module, "fc1", None))
  396. up_proj = getattr(module, "up_proj", None)
  397. down_proj = getattr(module, "down_proj", getattr(module, "fc2", None))
  398. # gate / input projections
  399. for proj in (gate_proj, up_proj):
  400. if proj is not None and hasattr(proj, "weight"):
  401. std = in_std or (proj.weight.shape[1] ** -0.5)
  402. init.trunc_normal_(
  403. proj.weight,
  404. mean=0.0,
  405. std=std,
  406. a=-3 * std,
  407. b=3 * std,
  408. )
  409. if getattr(proj, "bias", None) is not None:
  410. init.zeros_(proj.bias)
  411. # output/ down projections
  412. if down_proj is not None and hasattr(down_proj, "weight"):
  413. hidden_dim = down_proj.weight.shape[1]
  414. out_std = hidden_dim**-0.5
  415. init.trunc_normal_(
  416. down_proj.weight,
  417. mean=0.0,
  418. std=out_std,
  419. a=-3 * out_std,
  420. b=3 * out_std,
  421. )
  422. if getattr(down_proj, "bias", None) is not None:
  423. init.zeros_(down_proj.bias)
  424. return
  425. # Generic Linear layers (projections, lm_head, etc.)
  426. if isinstance(module, nn.Linear):
  427. fan_in = module.in_features
  428. std = fan_in**-0.5
  429. init.trunc_normal_(
  430. module.weight,
  431. mean=0.0,
  432. std=std,
  433. a=-3 * std,
  434. b=3 * std,
  435. )
  436. if module.bias is not None:
  437. init.zeros_(module.bias)
  438. return
  439. if isinstance(module, BltRotaryEmbedding):
  440. rope_fn = (
  441. ROPE_INIT_FUNCTIONS[module.rope_type]
  442. if module.rope_type != "default"
  443. else module.compute_default_rope_parameters
  444. )
  445. buffer_value, _ = rope_fn(module.config)
  446. init.copy_(module.inv_freq, buffer_value)
  447. init.copy_(module.original_inv_freq, buffer_value)
  448. class BltLocalEncoder(BltPreTrainedModel):
  449. config: BltLocalEncoderConfig
  450. _can_record_outputs = {
  451. "encoder_attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="local_encoder"),
  452. }
  453. def __init__(self, config: BltLocalEncoderConfig):
  454. super().__init__(config)
  455. self.gradient_checkpointing = False
  456. self.config = config
  457. self.layers = nn.ModuleList(
  458. [BltTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  459. )
  460. self.rotary_emb = BltRotaryEmbedding(config=config)
  461. self.patch_embedding_projection = nn.Linear(
  462. in_features=config.hidden_size,
  463. out_features=config.hidden_size * config.cross_attn_k,
  464. bias=False,
  465. )
  466. self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
  467. self.cross_attn_layers = nn.ModuleList()
  468. layers_to_add = config.num_hidden_layers if config.cross_attn_all_layers else 1
  469. for layer_idx in range(layers_to_add):
  470. self.cross_attn_layers.append(
  471. BltCrossAttention(config=config, layer_idx=layer_idx, hidden_size=config.hidden_size)
  472. )
  473. self.post_init()
  474. def forward(
  475. self,
  476. input_ids: torch.LongTensor | None = None,
  477. inputs_embeds: torch.Tensor | None = None,
  478. patch_embeds: torch.Tensor | None = None,
  479. attention_mask: torch.Tensor | None = None,
  480. position_ids: torch.LongTensor | None = None,
  481. past_key_values: Cache | None = None,
  482. encoder_attention_mask: torch.Tensor | None = None,
  483. num_patches: int | None = None,
  484. patch_ids: torch.Tensor | None = None,
  485. **kwargs: Unpack[TransformersKwargs],
  486. ):
  487. if inputs_embeds is None:
  488. inputs_embeds = self.embed_tokens(input_ids)
  489. batch_size = inputs_embeds.shape[0]
  490. hidden_states = F.dropout(inputs_embeds, p=self.config.dropout, training=self.training)
  491. if position_ids is None:
  492. position_ids = (
  493. torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0).expand(batch_size, -1)
  494. )
  495. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  496. hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training)
  497. for idx, layer in enumerate(self.layers):
  498. hidden_states = layer(
  499. hidden_states,
  500. position_embeddings=position_embeddings,
  501. attention_mask=attention_mask,
  502. past_key_values=past_key_values,
  503. **kwargs,
  504. )
  505. if idx == len(self.layers) - 1 or self.config.cross_attn_all_layers:
  506. patch_embeds = self.patch_reduce(hidden_states, num_patches, patch_ids)
  507. patch_embeds = self.patch_embedding_projection(patch_embeds)
  508. patch_embeds = patch_embeds.reshape(
  509. batch_size, patch_embeds.shape[1] * self.config.cross_attn_k, self.config.hidden_size
  510. )
  511. layer_idx = idx if self.config.cross_attn_all_layers else 0
  512. cross_attention_output, _ = self.cross_attn_layers[layer_idx](
  513. hidden_states=patch_embeds,
  514. cross_attention_states=hidden_states,
  515. attention_mask=encoder_attention_mask,
  516. **kwargs,
  517. )
  518. patch_embeds = patch_embeds + cross_attention_output
  519. encoder_cross_states = patch_embeds
  520. return hidden_states, encoder_cross_states
  521. def patch_reduce(self, hidden_states, max_num_patches, patch_ids):
  522. """
  523. Reduce variable length patches to single embedding per patch
  524. Note: this works with variable number of patches for different sequences in the batch
  525. It handles variable length patches by assuming that patch_lengths will be 0 for any
  526. extra patches on the *right*. Since there can be a variable number of patches
  527. this function also return the number of patches for each sequence in the batch.
  528. Any embeddings on the right that are not allocated to a patch
  529. (i.e. if the sum(patch_lengths[i]) < seq_len for any i)
  530. will be sent to a dummy patch, which is trimmed before returning.
  531. """
  532. batch_size = hidden_states.shape[0]
  533. embedding_dim = hidden_states.shape[-1]
  534. patch_ids = patch_ids.unsqueeze(-1).expand(-1, -1, hidden_states.shape[-1])
  535. reduced_embeddings = torch.zeros(
  536. (batch_size, max_num_patches, embedding_dim), dtype=hidden_states.dtype, device=hidden_states.device
  537. )
  538. reduced_embeddings = reduced_embeddings.scatter_reduce(
  539. src=hidden_states,
  540. dim=1,
  541. index=patch_ids,
  542. reduce="amax",
  543. include_self=False,
  544. )
  545. reduced_embeddings = reduced_embeddings[:, :max_num_patches, :]
  546. return reduced_embeddings
  547. class BltLocalDecoder(BltPreTrainedModel):
  548. config: BltLocalDecoderConfig
  549. def __init__(self, config: BltLocalDecoderConfig):
  550. super().__init__(config)
  551. self.gradient_checkpointing = False
  552. self.config = config
  553. self.cross_attn_decoder = True
  554. self.layers = nn.ModuleList(
  555. [BltTransformerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  556. )
  557. self.rotary_emb = BltRotaryEmbedding(config=config)
  558. self.patch_embedding_projection = nn.Linear(
  559. in_features=config.hidden_size_global,
  560. out_features=config.hidden_size * config.cross_attn_k,
  561. bias=False,
  562. )
  563. self.norm = BltRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  564. self.cross_attn_layers = nn.ModuleList()
  565. layers_to_add = config.num_hidden_layers if config.cross_attn_all_layers else 1
  566. for layer_idx in range(layers_to_add):
  567. self.cross_attn_layers.append(
  568. BltCrossAttention(config=config, layer_idx=layer_idx, hidden_size=config.hidden_size)
  569. )
  570. self.post_init()
  571. def forward(
  572. self,
  573. input_ids: torch.LongTensor | None = None,
  574. inputs_embeds: torch.Tensor | None = None,
  575. patch_embeds: torch.Tensor | None = None,
  576. attention_mask: torch.Tensor | None = None,
  577. position_ids: torch.LongTensor | None = None,
  578. past_key_values: Cache | None = None,
  579. encoder_attention_mask: torch.Tensor | None = None,
  580. **kwargs: Unpack[TransformersKwargs],
  581. ):
  582. batch_size = inputs_embeds.shape[0]
  583. hidden_states = inputs_embeds
  584. patch_embeds = self.patch_embedding_projection(patch_embeds)
  585. patch_embeds = patch_embeds.reshape(
  586. batch_size, patch_embeds.shape[1] * self.config.cross_attn_k, self.config.hidden_size
  587. )
  588. if patch_embeds is not None and not self.cross_attn_decoder:
  589. hidden_states = hidden_states + patch_embeds
  590. if position_ids is None:
  591. position_ids = (
  592. torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0).expand(batch_size, -1)
  593. )
  594. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  595. hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training)
  596. for i, layer in enumerate(self.layers):
  597. if i == 0 or self.config.cross_attn_all_layers:
  598. cross_attention_output, _ = self.cross_attn_layers[i](
  599. hidden_states=hidden_states,
  600. cross_attention_states=patch_embeds,
  601. attention_mask=encoder_attention_mask,
  602. **kwargs,
  603. )
  604. hidden_states = hidden_states + cross_attention_output
  605. hidden_states = layer(
  606. hidden_states,
  607. position_embeddings=position_embeddings,
  608. attention_mask=attention_mask,
  609. past_key_values=past_key_values,
  610. **kwargs,
  611. )
  612. logits = self.norm(hidden_states)
  613. return logits
  614. class BltGlobalTransformer(BltPreTrainedModel):
  615. config: BltGlobalTransformerConfig
  616. _can_record_outputs = {
  617. "global_attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="global_transformer"),
  618. }
  619. def __init__(self, config: BltGlobalTransformerConfig):
  620. super().__init__(config)
  621. self.config = config
  622. self.layers = nn.ModuleList()
  623. for layer_idx in range(config.num_hidden_layers):
  624. self.layers.append(BltTransformerLayer(config, layer_idx))
  625. self.rotary_emb = BltRotaryEmbedding(config=config)
  626. # Create token embedding projection (use nn.Identity() when no projection needed)
  627. if getattr(config, "encoder_cross_output_size", None) is not None:
  628. self.token_embedding_projection = nn.Linear(
  629. config.encoder_cross_output_size, config.hidden_size, bias=False
  630. )
  631. else:
  632. self.token_embedding_projection = nn.Identity()
  633. self.post_init()
  634. @deprecate_kwarg("input_embeds", version="5.6.0", new_name="inputs_embeds")
  635. def forward(
  636. self,
  637. inputs_embeds: torch.Tensor,
  638. attention_mask: torch.Tensor | None = None,
  639. position_ids: torch.LongTensor | None = None,
  640. past_key_values: Cache | None = None,
  641. **kwargs: Unpack[TransformersKwargs],
  642. ):
  643. batch_size, seq_len, _ = inputs_embeds.shape
  644. hidden_states = self.token_embedding_projection(inputs_embeds)
  645. hidden_states = F.dropout(hidden_states, p=self.config.dropout, training=self.training)
  646. if position_ids is None:
  647. position_ids = (
  648. torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device).unsqueeze(0).expand(batch_size, -1)
  649. )
  650. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  651. for i, layer in enumerate(self.layers):
  652. hidden_states = layer(
  653. hidden_states,
  654. position_embeddings=position_embeddings,
  655. attention_mask=attention_mask,
  656. past_key_values=past_key_values,
  657. **kwargs,
  658. )
  659. return hidden_states
  660. class BltPatcher(BltPreTrainedModel):
  661. config: BltPatcherConfig
  662. def __init__(self, config: BltPatcherConfig):
  663. super().__init__(config)
  664. self.rotary_emb = BltRotaryEmbedding(config=self.config)
  665. self.layers = nn.ModuleList()
  666. for layer_idx in range(self.config.num_hidden_layers):
  667. self.layers.append(BltTransformerLayer(self.config, layer_idx))
  668. self.embed_tokens = nn.Embedding(self.config.vocab_size, self.config.hidden_size)
  669. self.norm = BltRMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps)
  670. self.lm_head = nn.Linear(
  671. self.config.hidden_size,
  672. self.config.vocab_size,
  673. bias=False,
  674. )
  675. self.post_init()
  676. def forward(
  677. self,
  678. input_ids: torch.LongTensor | None = None,
  679. attention_mask: torch.Tensor | None = None,
  680. position_ids: torch.LongTensor | None = None,
  681. past_key_values: Cache | None = None,
  682. inputs_embeds: torch.FloatTensor | None = None,
  683. use_cache: bool | None = None,
  684. patch_size: int | None = None,
  685. threshold: float | None = None,
  686. max_patch_length: int | None = None,
  687. **kwargs: Unpack[TransformersKwargs],
  688. ):
  689. if (input_ids is None) ^ (inputs_embeds is not None):
  690. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  691. if inputs_embeds is None:
  692. inputs_embeds = self.embed_tokens(input_ids)
  693. if use_cache and past_key_values is None:
  694. past_key_values = DynamicCache(config=self.config)
  695. if position_ids is None:
  696. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  697. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  698. position_ids = position_ids.unsqueeze(0)
  699. causal_mask = create_causal_mask(
  700. config=self.config,
  701. inputs_embeds=inputs_embeds,
  702. attention_mask=attention_mask,
  703. past_key_values=past_key_values,
  704. position_ids=position_ids,
  705. )
  706. hidden_states = inputs_embeds
  707. position_embeddings = self.rotary_emb(hidden_states, position_ids)
  708. for layer in self.layers:
  709. hidden_states = layer(hidden_states, position_embeddings=position_embeddings, attention_mask=causal_mask)
  710. logits = self.lm_head(self.norm(hidden_states))
  711. prediction_entropies = torch.distributions.Categorical(logits=logits).entropy()
  712. batch_size, sequence_length = inputs_embeds.shape[:2]
  713. if patch_size is not None:
  714. patch_lengths = self.patch_lengths_from_entropies(
  715. entropies=prediction_entropies,
  716. sequence_length=sequence_length,
  717. patch_size=patch_size,
  718. threshold=threshold,
  719. )
  720. else:
  721. patch_lengths = torch.ones(
  722. (batch_size, sequence_length), dtype=inputs_embeds.dtype, device=inputs_embeds.device
  723. )
  724. patch_lengths = process_patch_lengths(patch_lengths, max_patch_length)
  725. return prediction_entropies, patch_lengths, logits
  726. @staticmethod
  727. def patch_lengths_from_entropies(
  728. entropies,
  729. sequence_length,
  730. patch_size=None,
  731. threshold=None,
  732. ):
  733. """
  734. Computes patch lengths from token entropies.
  735. Depending on whether a threshold is provided, the function uses either:
  736. - Thresholding the entropy values (when `threshold` is set).
  737. """
  738. batch_size = entropies.shape[0]
  739. # Always include token 0 and 1 as starting tokens
  740. init_tokens = (
  741. torch.tensor([0, 1], dtype=torch.long, device=entropies.device).unsqueeze(0).repeat(batch_size, 1)
  742. )
  743. offset = init_tokens.shape[1]
  744. # Ignore first token entropy (BOS)
  745. entropies = entropies[:, 1:]
  746. # Threshold the entropy values to define patch start points
  747. patch_mask = entropies > threshold
  748. seq_len = patch_mask.shape[1]
  749. # Create patch IDs (token indices), and add a sentinel to ensure alignment
  750. token_indices = torch.arange(seq_len, device=entropies.device).unsqueeze(0).expand(batch_size, -1)
  751. sentinel = torch.full_like(token_indices, seq_len)
  752. padded_indices = torch.cat([token_indices, sentinel], dim=1)
  753. # Pad mask with inverse to align sentinel correctly
  754. padded_mask = torch.cat([patch_mask, ~patch_mask], dim=1)
  755. # Select indices where mask is True
  756. patch_starts = padded_indices[padded_mask].reshape(batch_size, seq_len)
  757. max_valid_patches = patch_mask.sum(dim=1).max()
  758. patch_starts = patch_starts[:, :max_valid_patches]
  759. # Offset patch starts to account for the two initial tokens
  760. patch_start_ids = torch.cat((init_tokens, patch_starts + offset), dim=1)
  761. # Compute patch end positions by shifting start positions
  762. last_token = torch.full_like(patch_start_ids[:, :1], sequence_length - 1)
  763. patch_ends = torch.cat((patch_start_ids[:, 1:] - 1, last_token), dim=1)
  764. patch_lengths = patch_ends - patch_start_ids + 1
  765. return patch_lengths
  766. class BltModel(BltPreTrainedModel):
  767. def __init__(self, config: BltConfig):
  768. super().__init__(config)
  769. self.gradient_checkpointing = False
  770. self.config = config
  771. self.local_encoder = BltLocalEncoder(config.encoder_config)
  772. self.global_transformer = BltGlobalTransformer(config.global_config)
  773. self.local_decoder = BltLocalDecoder(config.decoder_config)
  774. num_embeddings = config.encoder_hash_byte_group_nb_functions * len(config.encoder_hash_byte_group_size)
  775. total_vocab_size = config.encoder_hash_byte_group_vocab * num_embeddings
  776. self.encoder_hash_tok_embedding = nn.Embedding(total_vocab_size, config.encoder_config.hidden_size)
  777. if self.config.patch_in_forward:
  778. self.patcher = BltPatcher(config.patcher_config)
  779. self.patcher.eval()
  780. for param in self.patcher.parameters():
  781. param.requires_grad = False
  782. else:
  783. self.patcher = None
  784. self.post_init()
  785. @merge_with_config_defaults
  786. @capture_outputs
  787. def forward(
  788. self,
  789. input_ids: torch.LongTensor | None = None,
  790. patch_lengths: torch.Tensor | None = None,
  791. attention_mask: torch.Tensor | None = None,
  792. position_ids: torch.LongTensor | None = None,
  793. past_key_values: Cache | None = None,
  794. inputs_embeds: torch.FloatTensor | None = None,
  795. use_cache: bool | None = None,
  796. **kwargs: Unpack[TransformersKwargs],
  797. ) -> tuple | BaseModelOutputWithPast:
  798. if (input_ids is None) ^ (inputs_embeds is not None):
  799. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  800. if use_cache:
  801. if past_key_values is None:
  802. past_key_values = EncoderDecoderCache(
  803. DynamicCache(config=self.config), DynamicCache(config=self.config)
  804. )
  805. elif not isinstance(past_key_values, EncoderDecoderCache):
  806. # BLT uses an encoder-decoder cache even though it is not en encoder-decoder model. Create a cross-cache
  807. # if not yet created by the user
  808. past_key_values = EncoderDecoderCache(past_key_values, DynamicCache(config=self.config))
  809. # Extract input embeddings as early as possible
  810. if inputs_embeds is not None:
  811. encoder_embeds = inputs_embeds
  812. batch_size, sequence_length, _ = inputs_embeds.shape
  813. else:
  814. batch_size, sequence_length = input_ids.shape
  815. encoder_embeds = compute_hash_embeddings(
  816. input_ids,
  817. self.local_encoder,
  818. self.encoder_hash_tok_embedding,
  819. self.config.encoder_hash_byte_group_nb_functions,
  820. self.config.encoder_hash_byte_group_size,
  821. self.config.encoder_hash_byte_group_vocab,
  822. )
  823. if patch_lengths is None:
  824. if self.config.patching_mode == "entropy" and self.patcher is not None:
  825. if input_ids is None:
  826. raise ValueError("input_ids is required for entropy-based patching")
  827. _, patch_lengths, _ = self.patcher(
  828. input_ids,
  829. patch_size=self.config.patch_size,
  830. threshold=self.config.patching_threshold,
  831. max_patch_length=self.config.max_patch_length,
  832. patching_batch_size=self.config.patching_batch_size,
  833. device=input_ids.device,
  834. )
  835. else:
  836. device = input_ids.device if input_ids is not None else inputs_embeds.device
  837. dtype = input_ids.dtype if input_ids is not None else inputs_embeds.dtype
  838. patch_lengths = process_patch_lengths(
  839. torch.ones((batch_size, sequence_length + 1), dtype=dtype, device=device),
  840. self.config.max_patch_length,
  841. )
  842. patch_ids = self._patch_ids_from_lengths(patch_lengths, sequence_length)
  843. if position_ids is None:
  844. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  845. position_ids = torch.arange(encoder_embeds.shape[1], device=encoder_embeds.device) + past_seen_tokens
  846. position_ids = position_ids.unsqueeze(0)
  847. causal_mask = create_causal_mask(
  848. config=self.config,
  849. inputs_embeds=encoder_embeds,
  850. attention_mask=attention_mask,
  851. past_key_values=past_key_values.self_attention_cache if past_key_values is not None else None,
  852. position_ids=position_ids,
  853. )
  854. cross_attn_mask_enc = _prepare_patch_cross_attention_mask(
  855. patch_ids=patch_ids,
  856. num_patches=patch_lengths.shape[1],
  857. sequence_length=sequence_length,
  858. patches_as_queries=True,
  859. cross_attn_k=self.config.cross_attn_k,
  860. dtype=encoder_embeds.dtype,
  861. )
  862. encoder_hidden_states, encoder_cross_states = self.local_encoder(
  863. input_ids=input_ids,
  864. inputs_embeds=encoder_embeds,
  865. attention_mask=causal_mask,
  866. position_ids=position_ids,
  867. encoder_attention_mask=cross_attn_mask_enc,
  868. num_patches=patch_lengths.shape[1],
  869. patch_ids=patch_ids,
  870. past_key_values=past_key_values.self_attention_cache if past_key_values is not None else None,
  871. **kwargs,
  872. )
  873. encoder_cross_states = encoder_cross_states.view(batch_size, patch_lengths.shape[1], -1)
  874. global_position_ids = torch.arange(0, encoder_cross_states.shape[1], device=encoder_cross_states.device)
  875. global_position_ids = global_position_ids.unsqueeze(0)
  876. global_causal_mask = create_causal_mask(
  877. config=self.config,
  878. inputs_embeds=encoder_cross_states,
  879. attention_mask=None,
  880. past_key_values=None,
  881. position_ids=None,
  882. )
  883. global_hidden_states = self.global_transformer(
  884. inputs_embeds=encoder_cross_states,
  885. attention_mask=global_causal_mask,
  886. position_ids=global_position_ids,
  887. **kwargs,
  888. )
  889. decoder_patch_ids = self._patch_ids_from_lengths(patch_lengths[:, 1:], sequence_length)
  890. cross_attn_mask_dec = _prepare_patch_cross_attention_mask(
  891. patch_ids=decoder_patch_ids,
  892. num_patches=patch_lengths.shape[1],
  893. sequence_length=sequence_length,
  894. patches_as_queries=False,
  895. cross_attn_k=self.config.cross_attn_k,
  896. dtype=encoder_embeds.dtype,
  897. )
  898. output = self.local_decoder(
  899. input_ids=input_ids,
  900. inputs_embeds=encoder_hidden_states,
  901. patch_embeds=global_hidden_states,
  902. attention_mask=causal_mask,
  903. position_ids=position_ids,
  904. past_key_values=past_key_values.cross_attention_cache if past_key_values is not None else None,
  905. encoder_attention_mask=cross_attn_mask_dec,
  906. **kwargs,
  907. )
  908. return BaseModelOutputWithPast(
  909. last_hidden_state=output,
  910. past_key_values=past_key_values,
  911. )
  912. def get_input_embeddings(self):
  913. return self.local_encoder.embed_tokens
  914. def set_input_embeddings(self, value):
  915. self.local_encoder.embed_tokens = value
  916. def _patch_ids_from_lengths(self, patch_lengths: torch.Tensor, seq_len: int) -> torch.Tensor:
  917. batch_size = patch_lengths.shape[0]
  918. patch_starts = torch.cat(
  919. [
  920. torch.zeros(batch_size, 1, dtype=patch_lengths.dtype, device=patch_lengths.device),
  921. patch_lengths.cumsum(dim=-1)[:, :-1],
  922. ],
  923. dim=-1,
  924. )
  925. token_positions = torch.arange(seq_len, device=patch_lengths.device)
  926. return (patch_starts.unsqueeze(1) <= token_positions.unsqueeze(0).unsqueeze(-1)).sum(dim=-1) - 1
  927. @auto_docstring(
  928. custom_intro="""
  929. The Blt Text Model with a language modeling head on top.
  930. """
  931. )
  932. class BltForCausalLM(BltPreTrainedModel, GenerationMixin):
  933. config: BltConfig
  934. _can_compile_fullgraph = False
  935. base_model_prefix = "model"
  936. _tied_weights_keys = {"model.local_encoder.embed_tokens.weight": "lm_head.weight"}
  937. def __init__(self, config: BltConfig):
  938. super().__init__(config)
  939. self.text_config = config.get_text_config()
  940. self.vocab_size = config.vocab_size
  941. self.model = BltModel(config)
  942. self.lm_head = nn.Linear(config.decoder_config.hidden_size, config.vocab_size, bias=False)
  943. self.post_init()
  944. @can_return_tuple
  945. @auto_docstring
  946. def forward(
  947. self,
  948. input_ids: torch.LongTensor | None = None,
  949. attention_mask: torch.Tensor | None = None,
  950. position_ids: torch.LongTensor | None = None,
  951. cross_attention_states: torch.LongTensor | None = None, # Keep for compatibility
  952. cross_attention_mask: torch.LongTensor | None = None,
  953. full_text_row_masked_out_mask: tuple[torch.Tensor, torch.Tensor] | None = None,
  954. past_key_values: Cache | None = None,
  955. inputs_embeds: torch.FloatTensor | None = None,
  956. labels: torch.LongTensor | None = None,
  957. use_cache: bool | None = None,
  958. logits_to_keep: int | torch.Tensor = 0,
  959. **kwargs: Unpack[TransformersKwargs],
  960. ) -> tuple | CausalLMOutputWithPast:
  961. r"""
  962. cross_attention_states (`torch.FloatTensor`, *optional*):
  963. Output of the vision model, used for cross-attention. This tensor contains the processed image features that
  964. the language model will attend to.
  965. cross_attention_mask (`torch.Tensor` of shape `(batch_size, seq_length, max_num_images, max_num_tiles)`, *optional*):
  966. Cross-attention mask to control the interaction between text tokens and image tiles.
  967. This 4D tensor defines which image tiles each text token should attend to.
  968. For each text token (in seq_length):
  969. - 1 indicates the token **should attend** to the corresponding image tile
  970. - 0 indicates the token **should not attend** to the corresponding image tile
  971. full_text_row_masked_out_mask (`tuple[torch.Tensor, torch.Tensor]`, *optional*):
  972. A tuple containing two tensors that mask out rows in the cross-attention mechanism:
  973. - The first tensor has shape `(batch_size, 1, seq_length, 1)` and contains values of 0 or 1.
  974. A value of 0 indicates that the corresponding text token's entire row in the cross-attention
  975. matrix should be masked out (all image tokens ignored).
  976. - The second tensor has the same shape and is used internally to apply the masking during
  977. the forward pass of cross-attention layers.
  978. This mask is derived from the cross_attention_mask and is used to handle cases where a text token
  979. should not attend to any image token.
  980. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  981. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  982. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  983. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  984. Example:
  985. ```python
  986. >>> from transformers import AutoTokenizer, BltForCausalLM
  987. >>> model = BltForCausalLM.from_pretrained("itazap/blt-1b-hf")
  988. >>> tokenizer = AutoTokenizer.from_pretrained("itazap/blt-1b-hf")
  989. >>> prompt = "If I had to write a haiku, it would be:"
  990. >>> inputs = tokenizer(prompt, return_tensors="pt")
  991. >>> # Generate
  992. >>> generate_ids = model.generate(inputs.input_ids, max_length=40, do_sample=True, temperature=0.6)
  993. >>> result = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  994. >>> print(result)
  995. If I had to write a haiku, it would be: "Snowflakes gently fall" - simple, yet peaceful.
  996. I love the idea of snowflakes gently falling, each one
  997. ```
  998. """
  999. # Call parent forward but exclude cross_attention_states from model call
  1000. outputs = self.model(
  1001. input_ids=input_ids,
  1002. attention_mask=attention_mask,
  1003. position_ids=position_ids,
  1004. cross_attention_mask=cross_attention_mask,
  1005. full_text_row_masked_out_mask=full_text_row_masked_out_mask,
  1006. past_key_values=past_key_values,
  1007. inputs_embeds=inputs_embeds,
  1008. use_cache=use_cache,
  1009. **kwargs,
  1010. )
  1011. hidden_states = outputs.last_hidden_state
  1012. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  1013. logits = self.lm_head(hidden_states[:, slice_indices, :]).float()
  1014. loss = None
  1015. if labels is not None:
  1016. loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
  1017. return CausalLMOutputWithPast(
  1018. loss=loss,
  1019. logits=logits,
  1020. past_key_values=outputs.past_key_values,
  1021. hidden_states=outputs.hidden_states,
  1022. attentions=outputs.attentions,
  1023. )
  1024. __all__ = [
  1025. "BltPreTrainedModel",
  1026. "BltModel",
  1027. "BltPatcher",
  1028. "BltForCausalLM",
  1029. ]