modular_glm4v.py 61 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421
  1. # Copyright 2025 The ZhipuAI Inc. team and HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import itertools
  15. from collections.abc import Callable
  16. import numpy as np
  17. import torch
  18. import torch.nn as nn
  19. import torch.nn.functional as F
  20. from huggingface_hub.dataclasses import strict
  21. from torch.nn import LayerNorm
  22. from ... import initialization as init
  23. from ...activations import ACT2FN
  24. from ...cache_utils import Cache, DynamicCache
  25. from ...configuration_utils import PreTrainedConfig
  26. from ...feature_extraction_utils import BatchFeature
  27. from ...image_utils import ImageInput
  28. from ...masking_utils import create_causal_mask
  29. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  30. from ...modeling_layers import GradientCheckpointingLayer
  31. from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling
  32. from ...modeling_rope_utils import RopeParameters
  33. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  34. from ...processing_utils import Unpack
  35. from ...tokenization_utils_base import PreTokenizedInput, TextInput
  36. from ...utils import (
  37. TransformersKwargs,
  38. auto_docstring,
  39. can_return_tuple,
  40. logging,
  41. torch_compilable_check,
  42. )
  43. from ...utils.generic import maybe_autocast, merge_with_config_defaults
  44. from ...utils.output_capturing import capture_outputs
  45. from ...video_utils import VideoInput
  46. from ..glm4.modeling_glm4 import Glm4MLP, Glm4RMSNorm, Glm4RotaryEmbedding, eager_attention_forward
  47. from ..qwen2_5_vl.modeling_qwen2_5_vl import (
  48. Qwen2_5_VisionPatchEmbed,
  49. Qwen2_5_VisionRotaryEmbedding,
  50. Qwen2_5_VLCausalLMOutputWithPast,
  51. Qwen2_5_VLForConditionalGeneration,
  52. Qwen2_5_VLMLP,
  53. Qwen2_5_VLModelOutputWithPast,
  54. Qwen2_5_VLPreTrainedModel,
  55. Qwen2_5_VLTextModel,
  56. Qwen2_5_VLVisionAttention,
  57. Qwen2_5_VLVisionBlock,
  58. )
  59. from ..qwen2_vl.modeling_qwen2_vl import Qwen2VLModel
  60. from ..qwen2_vl.processing_qwen2_vl import (
  61. Qwen2VLProcessor,
  62. Qwen2VLProcessorKwargs,
  63. )
# Module-level logger, obtained from the library's logging utilities and named after this module.
logger = logging.get_logger(__name__)
@auto_docstring(checkpoint="zai-org/GLM-4.1V-9B-Thinking")
@strict
class Glm4vVisionConfig(PreTrainedConfig):
    r"""
    out_hidden_size (`int`, *optional*, defaults to 4096):
        The output hidden size of the vision model.

    Example:

    ```python
    >>> from transformers import Glm4vVisionConfig, Glm4vVisionModel

    >>> # Initializing a Glm4vVisionConfig GLM-4.1V-9B style configuration
    >>> configuration = Glm4vVisionConfig()

    >>> # Initializing a model (with random weights) from the GLM-4.1V-9B configuration
    >>> model = Glm4vVisionModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "glm4v_vision"
    # Key under which this sub-config is stored inside the composite `Glm4vConfig`.
    base_config_key = "vision_config"

    # Number of `Glm4vVisionBlock`s stacked by `Glm4vVisionModel`.
    depth: int = 24
    # Width of the patch embeddings and of every vision block.
    hidden_size: int = 1536
    # Activation name; `Glm4vVisionModel` passes it to `Glm4vVisionPatchMerger` (looked up in `ACT2FN`).
    hidden_act: str = "silu"
    # Whether the packed `qkv` projection of `Glm4vVisionAttention` carries a bias.
    attention_bias: bool = False
    # Stored on `Glm4vVisionAttention` as `attention_dropout`.
    attention_dropout: float | int = 0.0
    # Attention heads per block; `Glm4vVisionModel` derives `head_dim = hidden_size // num_heads`.
    num_heads: int = 12
    # Input channels of the patch-embedding `Conv3d`.
    in_channels: int = 3
    # The learned position table holds `(image_size // patch_size) ** 2` entries.
    image_size: int | list[int] | tuple[int, int] = 336
    # Spatial kernel/stride of the patch-embedding `Conv3d`.
    # NOTE(review): code in this file treats it as a scalar int despite the wider annotation.
    patch_size: int | list[int] | tuple[int, int] = 14
    # Epsilon of the vision RMSNorm layers.
    rms_norm_eps: float = 1e-05
    # Side of the patch window folded into one token by the strided `downsample` convolution.
    spatial_merge_size: int = 2
    # Temporal kernel/stride of the patch-embedding `Conv3d`.
    temporal_patch_size: int | list[int] | tuple[int, int] = 2
    # Output channels of `downsample`; also the input/output width of `Glm4vVisionPatchMerger`.
    out_hidden_size: int = 4096
    # Hidden width (`context_dim`) of the gated MLP inside `Glm4vVisionPatchMerger`.
    intermediate_size: int = 13696
    # Presumably the std used by weight initialization (not referenced in this part of the file).
    initializer_range: float = 0.02
@auto_docstring(checkpoint="zai-org/GLM-4.1V-9B-Thinking")
@strict
class Glm4vTextConfig(PreTrainedConfig):
    r"""
    Example:

    ```python
    >>> from transformers import Glm4vTextModel, Glm4vTextConfig

    >>> # Initializing a GLM-4.1V style text configuration
    >>> configuration = Glm4vTextConfig()

    >>> # Initializing a model from the GLM-4.1V style text configuration
    >>> model = Glm4vTextModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "glm4v_text"
    # Key under which this sub-config is stored inside the composite `Glm4vConfig`.
    base_config_key = "text_config"
    keys_to_ignore_at_inference = ["past_key_values"]
    # Default tensor parallel plan for base model `Glm4v`
    base_model_tp_plan = {
        "layers.*.self_attn.q_proj": "colwise",
        "layers.*.self_attn.k_proj": "colwise",
        "layers.*.self_attn.v_proj": "colwise",
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.gate_up_proj": "colwise_gather_output",  # we need to replicate here due to the `chunk` operation
        "layers.*.mlp.down_proj": "rowwise_split_input",  # input is replicated due to the `chunk` operation
    }
    # Presumably the pipeline-parallel plan: module name -> (input names, output names).
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }
    # `mrope_section` is an extra rope key read by `Glm4vTextRotaryEmbedding`; presumably this
    # excludes it from generic rope validation.
    ignore_keys_at_rope_validation = {"mrope_section"}

    vocab_size: int = 151552
    # Model width; `Glm4vTextAttention` uses `head_dim = hidden_size // num_attention_heads`.
    hidden_size: int = 4096
    intermediate_size: int = 13696
    num_hidden_layers: int = 40
    num_attention_heads: int = 32
    # Heads of the key/value projections; `None` means "same as num_attention_heads" (see __post_init__).
    num_key_value_heads: int | None = 2
    hidden_act: str = "silu"
    max_position_embeddings: int = 32768
    initializer_range: float = 0.02
    # Epsilon of the decoder RMSNorm layers.
    rms_norm_eps: float = 1e-05
    use_cache: bool = True
    # Attention dropout probability, applied only in training mode.
    attention_dropout: float | int = 0.0
    # Rope settings; may include `mrope_section` (see `Glm4vTextRotaryEmbedding`).
    rope_parameters: RopeParameters | dict | None = None
    pad_token_id: int | None = None

    def __post_init__(self, **kwargs):
        # Without an explicit value, use one key/value head per query head.
        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.num_attention_heads
        super().__post_init__(**kwargs)
@auto_docstring(checkpoint="zai-org/GLM-4.1V-9B-Thinking")
@strict
class Glm4vConfig(PreTrainedConfig):
    r"""
    image_start_token_id (`int`, *optional*, defaults to 151339):
        The image start token index to encode the start of image.
    image_end_token_id (`int`, *optional*, defaults to 151340):
        The image end token index to encode the end of image.
    video_start_token_id (`int`, *optional*, defaults to 151341):
        The video start token index to encode the start of video.
    video_end_token_id (`int`, *optional*, defaults to 151342):
        The video end token index to encode the end of video.

    Example:

    ```python
    >>> from transformers import Glm4vForConditionalGeneration, Glm4vConfig

    >>> # Initializing a GLM-4.1V style configuration
    >>> configuration = Glm4vConfig()

    >>> # Initializing a model from the GLM-4.1V style configuration
    >>> model = Glm4vForConditionalGeneration(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "glm4v"
    # Classes used to materialize the `vision_config` / `text_config` fields in __post_init__.
    sub_configs = {"vision_config": Glm4vVisionConfig, "text_config": Glm4vTextConfig}
    keys_to_ignore_at_inference = ["past_key_values"]

    text_config: dict | PreTrainedConfig | None = None
    vision_config: dict | PreTrainedConfig | None = None
    image_token_id: int = 151343
    video_token_id: int = 151344
    image_start_token_id: int = 151339
    image_end_token_id: int = 151340
    video_start_token_id: int = 151341
    video_end_token_id: int = 151342
    tie_word_embeddings: bool = False

    def __post_init__(self, **kwargs):
        # Normalize each sub-config to a config instance: a dict is expanded into the matching
        # class from `sub_configs`, and a missing (None) sub-config is built from the extra kwargs.
        # NOTE(review): the same `kwargs` are forwarded to the vision config, the text config and
        # the parent `__post_init__` — confirm that sharing them is intended.
        if isinstance(self.vision_config, dict):
            self.vision_config = self.sub_configs["vision_config"](**self.vision_config)
        elif self.vision_config is None:
            self.vision_config = self.sub_configs["vision_config"](**kwargs)

        if isinstance(self.text_config, dict):
            self.text_config = self.sub_configs["text_config"](**self.text_config)
        elif self.text_config is None:
            self.text_config = self.sub_configs["text_config"](**kwargs)

        super().__post_init__(**kwargs)
# Will be used for both the text and vision modalities.
class Glm4vRMSNorm(Glm4RMSNorm):
    """RMSNorm shared by the GLM-4V text decoder and vision encoder; adds nothing to `Glm4RMSNorm`."""

    pass
class Glm4VisionMlp(Qwen2_5_VLMLP):
    """MLP used inside each `Glm4vVisionBlock`, derived from `Qwen2_5_VLMLP`."""

    def __init__(self, config, bias: bool = False):
        super().__init__(config, bias)
        # NOTE(review): as plain Python this assignment runs after the parent has built its layers,
        # so it only overwrites the attribute. Presumably the modular converter folds it into the
        # parent's __init__ so the layers are sized with `out_hidden_size` — confirm.
        self.intermediate_size = config.out_hidden_size
  198. class Glm4vVisionPatchEmbed(Qwen2_5_VisionPatchEmbed):
  199. def __init__(self, config: Glm4vVisionConfig) -> None:
  200. nn.Module.__init__(self)
  201. self.patch_size = config.patch_size
  202. self.temporal_patch_size = config.temporal_patch_size
  203. self.in_channels = config.in_channels
  204. self.embed_dim = config.hidden_size
  205. kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size]
  206. self.proj = nn.Conv3d(self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size)
class Glm4vVisionRotaryEmbedding(Qwen2_5_VisionRotaryEmbedding):
    """Rotary embedding of the vision encoder; inherits everything from `Qwen2_5_VisionRotaryEmbedding`."""

    pass
  209. class Glm4vVisionPatchMerger(nn.Module):
  210. def __init__(self, dim: int, context_dim: int, hidden_act: str, bias: bool = False) -> None:
  211. super().__init__()
  212. self.proj = nn.Linear(dim, dim, bias=bias)
  213. self.post_projection_norm = LayerNorm(dim)
  214. self.gate_proj = nn.Linear(dim, context_dim, bias=bias)
  215. self.up_proj = nn.Linear(dim, context_dim, bias=bias)
  216. self.down_proj = nn.Linear(context_dim, dim, bias=bias)
  217. self.act1 = nn.GELU()
  218. self.act_fn = ACT2FN[hidden_act]
  219. def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
  220. hidden_state = self.proj(hidden_state)
  221. hidden_state = self.act1(self.post_projection_norm(hidden_state))
  222. return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
  223. class Glm4vVisionEmbeddings(nn.Module):
  224. def __init__(self, config: Glm4vVisionConfig):
  225. super().__init__()
  226. self.config = config
  227. self.embed_dim = config.hidden_size
  228. self.image_size = config.image_size
  229. self.patch_size = config.patch_size
  230. self.num_patches = (self.image_size // self.patch_size) ** 2
  231. self.num_positions = self.num_patches
  232. self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
  233. self.interpolated_method = "bicubic"
  234. def forward(self, embeddings, lengths, image_shapes, h_coords, w_coords) -> torch.Tensor:
  235. """
  236. Forward pass with integrated position encoding adaptation using 2D interpolation.
  237. Args:
  238. embeddings: Input embeddings tensor
  239. lengths (torch.Tensor): Sequence lengths for each image in the batch.
  240. image_shapes (torch.Tensor): Tensor of shape [batch_size, 3] representing the image shapes (t, h, w).
  241. h_coords (torch.Tensor): Tensor of shape [total_seq] representing the h coordinate for each patch.
  242. w_coords (torch.Tensor): Tensor of shape [total_seq] representing the w coordinate for each patch.
  243. Returns:
  244. torch.Tensor: Embeddings with adapted position encoding added.
  245. """
  246. # Get position embedding parameters
  247. pos_embed_weight = self.position_embedding.weight
  248. hidden_size = pos_embed_weight.shape[1]
  249. device = pos_embed_weight.device
  250. # Convert inputs to tensors if needed
  251. if isinstance(lengths, list):
  252. lengths = torch.tensor(lengths, device=device, dtype=torch.long)
  253. # Prepare 2D position embedding
  254. orig_size_sq = pos_embed_weight.shape[0]
  255. orig_size = int(orig_size_sq**0.5)
  256. pos_embed_2d = (
  257. pos_embed_weight.view(orig_size, orig_size, hidden_size)
  258. .permute(2, 0, 1)
  259. .unsqueeze(0)
  260. .to(device=device, dtype=torch.float32)
  261. )
  262. # Calculate target dimensions for each patch
  263. target_h = torch.cat([image_shapes[i, 1].repeat(lengths[i]) for i in range(len(lengths))]).to(
  264. device=device, dtype=torch.float32
  265. )
  266. target_w = torch.cat([image_shapes[i, 2].repeat(lengths[i]) for i in range(len(lengths))]).to(
  267. device=device, dtype=torch.float32
  268. )
  269. # Normalize coordinates to [-1, 1] range for grid_sample
  270. norm_w = ((w_coords + 0.5) / target_w) * 2 - 1
  271. norm_h = ((h_coords + 0.5) / target_h) * 2 - 1
  272. # Create sampling grid
  273. grid = torch.stack((norm_w, norm_h), dim=-1).unsqueeze(0).unsqueeze(2)
  274. # Perform bicubic interpolation
  275. interpolated_embed_fp32 = F.grid_sample(
  276. pos_embed_2d, grid, mode=self.interpolated_method, align_corners=False, padding_mode="border"
  277. )
  278. # Reshape and convert back to original dtype
  279. adapted_pos_embed_fp32 = interpolated_embed_fp32.squeeze(0).squeeze(-1).permute(1, 0)
  280. adapted_pos_embed = adapted_pos_embed_fp32.to(pos_embed_weight.dtype).to(embeddings.device)
  281. # Add adapted position encoding to embeddings
  282. embeddings = embeddings + adapted_pos_embed
  283. return embeddings
class Glm4vVisionAttention(Qwen2_5_VLVisionAttention):
    """Vision self-attention derived from `Qwen2_5_VLVisionAttention` with GLM-4V projection layers."""

    def __init__(self, config: Glm4vVisionConfig) -> None:
        super().__init__(config)
        self.attention_dropout = config.attention_dropout
        # Packed projection producing 3 * hidden_size features (presumably q, k and v, split in the
        # inherited forward); its bias follows `config.attention_bias`.
        self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.attention_bias)
        # Output projection, always without bias.
        self.proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
class Glm4vVisionBlock(Qwen2_5_VLVisionBlock):
    """Vision transformer block; reuses the inherited forward with GLM-4V sub-modules."""

    def __init__(self, config) -> None:
        super().__init__(config)
        # (Re)assign the sub-modules with GLM-4V variants: RMSNorms, attention and MLP.
        self.norm1 = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.norm2 = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.attn = Glm4vVisionAttention(config)
        self.mlp = Glm4VisionMlp(config, bias=False)
  297. class Glm4vTextRotaryEmbedding(Glm4RotaryEmbedding):
  298. def __init__(self, config: Glm4vTextConfig, device=None):
  299. super().__init__()
  300. self.mrope_section = config.rope_parameters.get("mrope_section", [8, 12, 12])
  301. def forward(self, x, position_ids):
  302. # In contrast to other models, GLM-V has different position ids for the grids
  303. # So we expand the inv_freq to shape (3, ...)
  304. inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
  305. position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions)
  306. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  307. with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  308. freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
  309. freqs = self.apply_mrope(freqs, self.mrope_section)
  310. emb = torch.cat((freqs, freqs), dim=-1)
  311. cos = emb.cos() * self.attention_scaling
  312. sin = emb.sin() * self.attention_scaling
  313. return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  314. def apply_mrope(self, freqs, mrope_section):
  315. section = mrope_section
  316. chunks = freqs.split(section, dim=-1)
  317. result = torch.cat([chunk[i % 3] for i, chunk in enumerate(chunks)], dim=-1)
  318. return result
  319. def rotate_half_llm(x):
  320. """Rotates half the hidden dims of the input."""
  321. x1 = x[..., 0::2]
  322. x2 = x[..., 1::2]
  323. return torch.stack((-x2, x1), dim=-1).flatten(-2)
  324. def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
  325. """Applies Rotary Position Embedding to the query and key tensors.
  326. Args:
  327. q (`torch.Tensor`): The query tensor.
  328. k (`torch.Tensor`): The key tensor.
  329. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  330. sin (`torch.Tensor`): The sine part of the rotary embedding.
  331. unsqueeze_dim (`int`, *optional*, defaults to 1):
  332. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  333. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  334. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  335. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  336. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  337. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  338. Returns:
  339. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  340. """
  341. cos = cos.unsqueeze(unsqueeze_dim)
  342. sin = sin.unsqueeze(unsqueeze_dim)
  343. # Interleave them instead of usual shape
  344. cos = cos[..., : cos.shape[-1] // 2].repeat_interleave(2, dim=-1)
  345. sin = sin[..., : sin.shape[-1] // 2].repeat_interleave(2, dim=-1)
  346. # Keep half or full tensor for later concatenation
  347. rotary_dim = cos.shape[-1]
  348. q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
  349. k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
  350. # Apply rotary embeddings on the first half or full tensor
  351. q_embed = (q_rot * cos) + (rotate_half_llm(q_rot) * sin)
  352. k_embed = (k_rot * cos) + (rotate_half_llm(k_rot) * sin)
  353. # Concatenate back to full shape
  354. q_embed = torch.cat([q_embed, q_pass], dim=-1)
  355. k_embed = torch.cat([k_embed, k_pass], dim=-1)
  356. return q_embed, k_embed
  357. class Glm4vTextAttention(nn.Module):
  358. """
  359. Multi-headed attention from 'Attention Is All You Need' paper.
  360. and "Generating Long Sequences with Sparse Transformers".
  361. """
  362. def __init__(self, config: Glm4vTextConfig, layer_idx: int | None = None):
  363. super().__init__()
  364. self.config = config
  365. self.layer_idx = layer_idx
  366. self.hidden_size = config.hidden_size
  367. self.num_heads = config.num_attention_heads
  368. self.head_dim = self.hidden_size // self.num_heads
  369. self.num_key_value_heads = config.num_key_value_heads
  370. self.num_key_value_groups = self.num_heads // self.num_key_value_heads
  371. self.is_causal = True
  372. self.attention_dropout = config.attention_dropout
  373. self.rope_parameters = config.rope_parameters
  374. self.scaling = self.head_dim**-0.5
  375. self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
  376. self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
  377. self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
  378. self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
  379. def forward(
  380. self,
  381. hidden_states: torch.Tensor,
  382. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  383. attention_mask: torch.Tensor | None = None,
  384. past_key_values: Cache | None = None,
  385. **kwargs: Unpack[FlashAttentionKwargs],
  386. ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
  387. bsz, q_len, _ = hidden_states.size()
  388. query_states = self.q_proj(hidden_states)
  389. key_states = self.k_proj(hidden_states)
  390. value_states = self.v_proj(hidden_states)
  391. query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
  392. key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
  393. value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
  394. cos, sin = position_embeddings
  395. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  396. if past_key_values is not None:
  397. key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)
  398. attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
  399. self.config._attn_implementation, eager_attention_forward
  400. )
  401. attn_output, attn_weights = attention_interface(
  402. self,
  403. query_states,
  404. key_states,
  405. value_states,
  406. attention_mask,
  407. dropout=0.0 if not self.training else self.attention_dropout,
  408. scaling=self.scaling,
  409. **kwargs,
  410. )
  411. attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
  412. attn_output = self.o_proj(attn_output)
  413. return attn_output, attn_weights
class Glm4vTextMLP(Glm4MLP):
    """Feed-forward block of the GLM-4V text decoder; adds nothing to `Glm4MLP`."""

    pass
  416. class Glm4vTextDecoderLayer(GradientCheckpointingLayer):
  417. def __init__(self, config: Glm4vTextConfig, layer_idx: int):
  418. super().__init__()
  419. self.hidden_size = config.hidden_size
  420. self.self_attn = Glm4vTextAttention(config, layer_idx)
  421. self.mlp = Glm4vTextMLP(config)
  422. self.input_layernorm = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  423. self.post_attention_layernorm = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  424. self.post_self_attn_layernorm = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  425. self.post_mlp_layernorm = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  426. @auto_docstring
  427. def forward(
  428. self,
  429. hidden_states: torch.Tensor,
  430. position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
  431. attention_mask: torch.Tensor | None = None,
  432. position_ids: torch.LongTensor | None = None,
  433. past_key_values: Cache | None = None,
  434. use_cache: bool | None = False,
  435. **kwargs,
  436. ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
  437. residual = hidden_states
  438. hidden_states = self.input_layernorm(hidden_states)
  439. # Self Attention
  440. hidden_states, _ = self.self_attn(
  441. hidden_states=hidden_states,
  442. position_embeddings=position_embeddings,
  443. attention_mask=attention_mask,
  444. position_ids=position_ids,
  445. past_key_values=past_key_values,
  446. use_cache=use_cache,
  447. **kwargs,
  448. )
  449. hidden_states = self.post_self_attn_layernorm(hidden_states)
  450. hidden_states = residual + hidden_states
  451. # Fully Connected
  452. residual = hidden_states
  453. hidden_states = self.post_attention_layernorm(hidden_states)
  454. hidden_states = self.mlp(hidden_states)
  455. hidden_states = self.post_mlp_layernorm(hidden_states)
  456. hidden_states = residual + hidden_states
  457. return hidden_states
class Glm4vModelOutputWithPast(Qwen2_5_VLModelOutputWithPast):
    """Output container of the GLM-4V base model; same fields as `Qwen2_5_VLModelOutputWithPast`."""

    pass
  460. class Glm4vPreTrainedModel(Qwen2_5_VLPreTrainedModel):
  461. _no_split_modules = ["Glm4vTextDecoderLayer", "Glm4vVisionBlock"]
  462. def _init_weights(self, module):
  463. PreTrainedModel._init_weights(self, module)
  464. if isinstance(module, Glm4vVisionRotaryEmbedding):
  465. inv_freq = 1.0 / (module.theta ** (torch.arange(0, module.dim, 2, dtype=torch.float) / module.dim))
  466. init.copy_(module.inv_freq, inv_freq)
  467. class Glm4vVisionModel(Glm4vPreTrainedModel):
  468. config: Glm4vVisionConfig
  469. input_modalities = ("image", "video")
  470. _no_split_modules = ["Glm4vVisionBlock"]
  471. _can_record_outputs = {
  472. "hidden_states": Glm4vVisionBlock,
  473. "attentions": Glm4vVisionAttention,
  474. }
  475. def __init__(self, config) -> None:
  476. super().__init__(config)
  477. self.spatial_merge_size = config.spatial_merge_size
  478. self.patch_size = config.patch_size
  479. self.embeddings = Glm4vVisionEmbeddings(config)
  480. self.patch_embed = Glm4vVisionPatchEmbed(config)
  481. head_dim = config.hidden_size // config.num_heads
  482. self.rotary_pos_emb = Glm4vVisionRotaryEmbedding(head_dim // 2)
  483. self.blocks = nn.ModuleList([Glm4vVisionBlock(config) for _ in range(config.depth)])
  484. self.merger = Glm4vVisionPatchMerger(
  485. dim=config.out_hidden_size, context_dim=config.intermediate_size, hidden_act=config.hidden_act
  486. )
  487. self.post_conv_layernorm = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  488. self.downsample = nn.Conv2d(
  489. in_channels=config.hidden_size,
  490. out_channels=config.out_hidden_size,
  491. kernel_size=config.spatial_merge_size,
  492. stride=config.spatial_merge_size,
  493. )
  494. self.post_layernorm = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  495. self.gradient_checkpointing = False
  496. self.post_init()
  497. def rot_pos_emb(self, grid_thw):
  498. pos_ids = []
  499. for t, h, w in grid_thw:
  500. hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
  501. hpos_ids = hpos_ids.reshape(
  502. h // self.spatial_merge_size,
  503. self.spatial_merge_size,
  504. w // self.spatial_merge_size,
  505. self.spatial_merge_size,
  506. )
  507. hpos_ids = hpos_ids.permute(0, 2, 1, 3)
  508. hpos_ids = hpos_ids.flatten()
  509. wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
  510. wpos_ids = wpos_ids.reshape(
  511. h // self.spatial_merge_size,
  512. self.spatial_merge_size,
  513. w // self.spatial_merge_size,
  514. self.spatial_merge_size,
  515. )
  516. wpos_ids = wpos_ids.permute(0, 2, 1, 3)
  517. wpos_ids = wpos_ids.flatten()
  518. pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
  519. pos_ids = torch.cat(pos_ids, dim=0)
  520. max_grid_size = grid_thw[:, 1:].max()
  521. rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
  522. rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
  523. return rotary_pos_emb, pos_ids
    @merge_with_config_defaults
    @capture_outputs
    @auto_docstring
    def forward(
        self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs: Unpack[TransformersKwargs]
    ) -> tuple | BaseModelOutputWithPooling:
        r"""
        hidden_states (`torch.Tensor`):
            Input of the vision tower; it is projected by `self.patch_embed` before anything else. Presumably the
            flattened pixel patches of every image/frame concatenated along the first dimension (TODO: confirm the
            exact shape expected by `patch_embed`).
        grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
            The temporal, height and width of feature shape of each image in LLM.

        Returns:
            [`BaseModelOutputWithPooling`]: `last_hidden_state` is the output of `self.downsample` reshaped to
            `(-1, config.out_hidden_size)`, and `pooler_output` is that tensor passed through `self.merger`.
        """
        hidden_states = self.patch_embed(hidden_states)
        hidden_states = self.post_conv_layernorm(hidden_states)
        # Despite its name, `image_type_ids` holds the per-patch (h, w) position ids returned by `rot_pos_emb`
        # (column 0 = height position, column 1 = width position), not modality ids.
        rotary_pos_emb, image_type_ids = self.rot_pos_emb(grid_thw)
        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        position_embeddings = (emb.cos(), emb.sin())
        # Cumulative patch counts: each of the `t` frames of an entry contributes `h * w` patches. Presumably the
        # blocks use these as variable-length sequence boundaries (one sequence per frame).
        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
            dim=0,
            # Select dtype based on the following factors:
            #  - FA2 requires that cu_seqlens_q must have dtype int32
            #  - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
            # See https://github.com/huggingface/transformers/pull/34852 for more information
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
        # Prepend 0 so cu_seqlens[i]:cu_seqlens[i + 1] delimits frame i.
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
        # Per-frame patch counts as plain Python ints.
        seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
        hidden_states = self.embeddings(
            hidden_states,
            seqlens,
            grid_thw,
            image_type_ids[:, 0].to(hidden_states.device),
            image_type_ids[:, 1].to(hidden_states.device),
        )
        for blk in self.blocks:
            hidden_states = blk(
                hidden_states,
                cu_seqlens=cu_seqlens,
                position_embeddings=position_embeddings,
                **kwargs,
            )
        hidden_states = self.post_layernorm(hidden_states)
        # Group consecutive tokens into spatial_merge_size x spatial_merge_size windows (requires the token count
        # to be divisible by spatial_merge_size**2), channels-first for `self.downsample`.
        hidden_states = hidden_states.view(
            -1, self.spatial_merge_size, self.spatial_merge_size, hidden_states.shape[-1]
        )
        hidden_states = hidden_states.permute(0, 3, 1, 2)
        hidden_states = self.downsample(hidden_states).view(-1, self.config.out_hidden_size)
        merged_hidden_states = self.merger(hidden_states)
        return BaseModelOutputWithPooling(
            last_hidden_state=hidden_states,
            pooler_output=merged_hidden_states,
        )
  578. class Glm4vTextModel(Qwen2_5_VLTextModel):
  579. _can_record_outputs = {
  580. "hidden_states": Glm4vTextDecoderLayer,
  581. "attentions": Glm4vTextAttention,
  582. }
  583. def __init__(self, config: Glm4vTextConfig):
  584. super().__init__(config)
  585. self.layers = nn.ModuleList(
  586. [Glm4vTextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
  587. )
  588. self.norm = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  589. self.rotary_emb = Glm4vTextRotaryEmbedding(config=config)
  590. del self._attn_implementation
  591. del self.has_sliding_layers
  592. @auto_docstring
  593. @merge_with_config_defaults
  594. @capture_outputs
  595. def forward(
  596. self,
  597. input_ids: torch.LongTensor | None = None,
  598. attention_mask: torch.Tensor | None = None,
  599. position_ids: torch.LongTensor | None = None,
  600. past_key_values: Cache | None = None,
  601. inputs_embeds: torch.FloatTensor | None = None,
  602. use_cache: bool | None = None,
  603. **kwargs: Unpack[FlashAttentionKwargs],
  604. ) -> tuple | BaseModelOutputWithPast:
  605. if (input_ids is None) ^ (inputs_embeds is not None):
  606. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  607. # torch.jit.trace() doesn't support cache objects in the output
  608. if use_cache and past_key_values is None and not torch.jit.is_tracing():
  609. past_key_values = DynamicCache(config=self.config)
  610. if inputs_embeds is None:
  611. inputs_embeds = self.embed_tokens(input_ids)
  612. # the hard coded `3` is for temporal, height and width.
  613. if position_ids is None:
  614. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  615. position_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens
  616. position_ids = position_ids.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
  617. elif position_ids.ndim == 2:
  618. position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
  619. # NOTE: we need to pass text position ids for packing. Qwen2-VL uses 3D positions
  620. # where each dim indicates visual spatial positions for temporal/height/width grids.
  621. # There are two scenarios when FA2-like packed masking might be activated.
  622. # 1. User specifically passed packed `position_ids` and no attention mask.
  623. # In this case we expect the useer to create correct position ids for all 3 grids
  624. # and prepend text-only position ids to it. The final tensor will be [4, bs, seq-len]
  625. # 2. User runs forward with no attention mask and no position ids. In this case, position ids
  626. # are prepared by the model (`get_rope_index`) as `[4, bs, seq-len]` tensor. Text-only positions are
  627. # prepended by us when creating positions so that the mask is constructed correctly. NOTE: failing to pass
  628. # text-only positions will cause incorrect mask construction, do not change `prepare_input_for_generation`
  629. if position_ids.ndim == 3 and position_ids.shape[0] == 4:
  630. text_position_ids = position_ids[0]
  631. position_ids = position_ids[1:]
  632. else:
  633. # If inputs are not packed (usual 3D positions), do not prepare mask from position_ids
  634. text_position_ids = None
  635. mask_kwargs = {
  636. "config": self.config,
  637. "inputs_embeds": inputs_embeds,
  638. "attention_mask": attention_mask,
  639. "past_key_values": past_key_values,
  640. "position_ids": text_position_ids,
  641. }
  642. # Create the masks
  643. causal_mask = create_causal_mask(**mask_kwargs)
  644. hidden_states = inputs_embeds
  645. position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
  646. for decoder_layer in self.layers:
  647. layer_outputs = decoder_layer(
  648. hidden_states,
  649. attention_mask=causal_mask,
  650. position_ids=text_position_ids,
  651. past_key_values=past_key_values,
  652. position_embeddings=position_embeddings,
  653. **kwargs,
  654. )
  655. hidden_states = layer_outputs
  656. hidden_states = self.norm(hidden_states)
  657. return BaseModelOutputWithPast(
  658. last_hidden_state=hidden_states,
  659. past_key_values=past_key_values,
  660. )
  661. class Glm4vModel(Qwen2VLModel):
  662. _no_split_modules = ["Glm4vTextDecoderLayer", "Glm4vVisionBlock"]
  663. def __init__(self, config):
  664. super().__init__(config)
  665. self.visual = Glm4vVisionModel._from_config(config.vision_config)
  666. @can_return_tuple
  667. @auto_docstring
  668. def get_video_features(
  669. self,
  670. pixel_values_videos: torch.FloatTensor,
  671. video_grid_thw: torch.LongTensor | None = None,
  672. **kwargs: Unpack[TransformersKwargs],
  673. ) -> tuple | BaseModelOutputWithPooling:
  674. r"""
  675. pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
  676. The tensors corresponding to the input videos.
  677. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
  678. The temporal, height and width of feature shape of each video in LLM.
  679. """
  680. pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
  681. # reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames
  682. temp_frames_hw = []
  683. video_grid_thw_list = video_grid_thw.tolist()
  684. for t, h, w in video_grid_thw_list:
  685. repeated_row = torch.tensor([1, h, w]).unsqueeze(0).repeat(t, 1)
  686. temp_frames_hw.append(repeated_row)
  687. flattened_video_grid_thw = torch.cat(temp_frames_hw, dim=0)
  688. vision_outputs = self.visual(
  689. pixel_values_videos, grid_thw=flattened_video_grid_thw, return_dict=True, **kwargs
  690. )
  691. split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
  692. video_embeds = torch.split(vision_outputs.pooler_output, split_sizes)
  693. vision_outputs.pooler_output = video_embeds
  694. return vision_outputs
  695. def get_placeholder_mask(
  696. self,
  697. input_ids: torch.LongTensor,
  698. inputs_embeds: torch.FloatTensor,
  699. image_features: torch.FloatTensor | None = None,
  700. video_features: torch.FloatTensor | None = None,
  701. ):
  702. """
  703. Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
  704. equal to the length of multimodal features. If the lengths are different, an error is raised.
  705. """
  706. if input_ids is None:
  707. special_image_mask = inputs_embeds == self.get_input_embeddings()(
  708. torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
  709. )
  710. special_image_mask = special_image_mask.all(-1)
  711. special_video_mask = inputs_embeds == self.get_input_embeddings()(
  712. torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
  713. )
  714. special_video_mask = special_video_mask.all(-1)
  715. else:
  716. # GLM-4.1V and GLM-4.5V special_video_mask is special_image_mask
  717. special_image_mask = input_ids == self.config.image_token_id
  718. special_video_mask = input_ids == self.config.image_token_id
  719. n_image_tokens = special_image_mask.sum()
  720. special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
  721. if image_features is not None:
  722. torch_compilable_check(
  723. inputs_embeds[special_image_mask].numel() == image_features.numel(),
  724. f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {image_features.shape[0]}",
  725. )
  726. n_video_tokens = special_video_mask.sum()
  727. special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
  728. if video_features is not None:
  729. torch_compilable_check(
  730. inputs_embeds[special_video_mask].numel() == video_features.numel(),
  731. f"Video features and video tokens do not match, tokens: {n_video_tokens}, features: {video_features.shape[0]}",
  732. )
  733. return special_image_mask, special_video_mask
  734. def get_rope_index(
  735. self,
  736. input_ids: torch.LongTensor,
  737. mm_token_type_ids: torch.IntTensor,
  738. image_grid_thw: torch.LongTensor | None = None,
  739. video_grid_thw: torch.LongTensor | None = None,
  740. attention_mask: torch.Tensor | None = None,
  741. **kwargs,
  742. ) -> tuple[torch.Tensor, torch.Tensor]:
  743. """
  744. Calculate the 3D rope index based on image and video's sizes. The utility expects a `vision + text`
  745. sequence and will error out otherwise. For pure text sequence, please rely on model's auto-inferred
  746. position ids. In a mixed vision + text sequence, vision tokens use 3D RoPE (temporal, height, width)
  747. while text tokens use standard 1D RoPE.
  748. Example:
  749. Temporal patches: 3; Height patches: 2; Width patches: 2
  750. Each vision input results in (temporal x height × width) positions. Here: 3 x 2 × 2 = 12 positions total.
  751. Temporal position IDs are spaced by:
  752. `interval = tokens_per_second * temporal_patch_size / fps`
  753. If fps = 1; tokens_per_second = 25; temporal_patch_size = 2, temporal IDs increase by 50 for each temporal patch:
  754. `[0, 0, 0, 0, 50, 50, 50, 50, 100, 100, 100, 100]`
  755. Height IDs repeat per row: `[0, 0, 1, 1, ...]`
  756. Width IDs alternate per column: `[0, 1, 0, 1, ...]`
  757. Text tokens follow standard 1D RoPE and the position IDs grow consequently with a step of `1`
  758. Args:
  759. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  760. Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
  761. it.
  762. mm_token_type_ids (`torch.IntTensor` of shape `(batch_size, sequence_length)`):
  763. Token type ids matching each modality to a different value in the input sequence, i.e. text (0), image (1), video (2).
  764. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
  765. The temporal, height and width of feature shape of each image in LLM.
  766. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
  767. The temporal, height and width of feature shape of each video in LLM.
  768. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  769. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  770. - 1 for tokens that are **not masked**,
  771. - 0 for tokens that are **masked**.
  772. Returns:
  773. position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`)
  774. mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`)
  775. """
  776. spatial_merge_size = self.config.vision_config.spatial_merge_size
  777. mrope_position_deltas = []
  778. position_ids = torch.zeros(
  779. 3,
  780. input_ids.shape[0],
  781. input_ids.shape[1],
  782. dtype=input_ids.dtype,
  783. device=input_ids.device,
  784. )
  785. grid_iters = {
  786. 1: iter(image_grid_thw) if image_grid_thw is not None else None,
  787. 2: iter(video_grid_thw) if video_grid_thw is not None else None,
  788. }
  789. for batch_idx, current_input_ids in enumerate(input_ids):
  790. input_token_type = mm_token_type_ids[batch_idx]
  791. if attention_mask is not None:
  792. current_input_ids = current_input_ids[attention_mask[batch_idx].bool()]
  793. input_token_type = input_token_type[attention_mask[batch_idx].bool()]
  794. input_type_group = []
  795. for key, group in itertools.groupby(enumerate(input_token_type.tolist()), lambda x: x[1]):
  796. group = list(group)
  797. start_index = group[0][0]
  798. end_index = group[-1][0] + 1
  799. input_type_group.append((key, start_index, end_index))
  800. current_pos = 0
  801. video_group_index = 0
  802. llm_pos_ids_list = []
  803. for modality_type, start_idx, end_idx in input_type_group:
  804. # text == 0
  805. if modality_type == 0:
  806. text_len = end_idx - start_idx
  807. llm_pos_ids_list.append(
  808. torch.arange(text_len, device=input_ids.device).view(1, -1).expand(3, -1) + current_pos
  809. )
  810. current_pos += text_len
  811. # image == 1, video == 2
  812. else:
  813. # GLM4V splits video into segments per frame but there's only one `grid_thw`
  814. # per whole video. We can't exhaus the iterator and have to re-use the grid
  815. # while processing the same video!
  816. if modality_type == 2:
  817. if video_group_index == 0:
  818. grid_thw = next(grid_iters[modality_type])
  819. video_group_index += 1
  820. video_group_index = 0 if video_group_index >= grid_thw[0] else video_group_index
  821. else:
  822. grid_thw = next(grid_iters[modality_type])
  823. # Videos are processed per frame separately, each temporal grid is always `1`
  824. temp_merge_size = grid_thw[0]
  825. vision_position_ids = self.get_vision_position_ids(
  826. current_pos, grid_thw, temp_merge_size, spatial_merge_size, device=input_ids.device
  827. )
  828. llm_pos_ids_list.append(vision_position_ids)
  829. current_pos += max(grid_thw[1], grid_thw[2]) // spatial_merge_size
  830. llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
  831. if attention_mask is not None:
  832. position_ids[:, batch_idx, attention_mask[batch_idx].bool()] = llm_positions.to(position_ids.device)
  833. else:
  834. position_ids[:, batch_idx] = llm_positions.to(position_ids.device)
  835. mrope_position_deltas.append(llm_positions.max() + 1 - len(current_input_ids))
  836. mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
  837. return position_ids, mrope_position_deltas
  838. @auto_docstring
  839. @can_return_tuple
  840. def forward(
  841. self,
  842. input_ids: torch.LongTensor | None = None,
  843. attention_mask: torch.Tensor | None = None,
  844. position_ids: torch.LongTensor | None = None,
  845. past_key_values: Cache | None = None,
  846. inputs_embeds: torch.FloatTensor | None = None,
  847. pixel_values: torch.Tensor | None = None,
  848. pixel_values_videos: torch.FloatTensor | None = None,
  849. image_grid_thw: torch.LongTensor | None = None,
  850. video_grid_thw: torch.LongTensor | None = None,
  851. rope_deltas: torch.LongTensor | None = None,
  852. mm_token_type_ids: torch.IntTensor | None = None,
  853. **kwargs: Unpack[TransformersKwargs],
  854. ) -> tuple | Glm4vModelOutputWithPast:
  855. r"""
  856. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
  857. The temporal, height and width of feature shape of each image in LLM.
  858. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
  859. The temporal, height and width of feature shape of each video in LLM.
  860. rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
  861. The rope index difference between sequence length and multimodal rope.
  862. """
  863. if (input_ids is None) ^ (inputs_embeds is not None):
  864. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  865. if inputs_embeds is None:
  866. inputs_embeds = self.get_input_embeddings()(input_ids)
  867. if pixel_values is not None:
  868. image_embeds = self.get_image_features(pixel_values, image_grid_thw, return_dict=True).pooler_output
  869. image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
  870. image_mask, _ = self.get_placeholder_mask(input_ids, inputs_embeds, image_features=image_embeds)
  871. inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
  872. if pixel_values_videos is not None:
  873. video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw, return_dict=True).pooler_output
  874. video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
  875. _, video_mask = self.get_placeholder_mask(input_ids, inputs_embeds, video_features=video_embeds)
  876. inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
  877. if position_ids is None:
  878. position_ids = self.compute_3d_position_ids(
  879. input_ids=input_ids,
  880. image_grid_thw=image_grid_thw,
  881. video_grid_thw=video_grid_thw,
  882. inputs_embeds=inputs_embeds,
  883. attention_mask=attention_mask,
  884. past_key_values=past_key_values,
  885. mm_token_type_ids=mm_token_type_ids,
  886. )
  887. outputs = self.language_model(
  888. input_ids=None,
  889. position_ids=position_ids,
  890. attention_mask=attention_mask,
  891. past_key_values=past_key_values,
  892. inputs_embeds=inputs_embeds,
  893. **kwargs,
  894. )
  895. return Glm4vModelOutputWithPast(
  896. **outputs,
  897. rope_deltas=self.rope_deltas,
  898. )
  899. class Glm4vCausalLMOutputWithPast(Qwen2_5_VLCausalLMOutputWithPast):
  900. pass
  901. class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
  902. def forward(
  903. self,
  904. input_ids: torch.LongTensor | None = None,
  905. attention_mask: torch.Tensor | None = None,
  906. position_ids: torch.LongTensor | None = None,
  907. past_key_values: Cache | None = None,
  908. inputs_embeds: torch.FloatTensor | None = None,
  909. labels: torch.LongTensor | None = None,
  910. pixel_values: torch.Tensor | None = None,
  911. pixel_values_videos: torch.FloatTensor | None = None,
  912. image_grid_thw: torch.LongTensor | None = None,
  913. video_grid_thw: torch.LongTensor | None = None,
  914. mm_token_type_ids: torch.IntTensor | None = None,
  915. logits_to_keep: int | torch.Tensor = 0,
  916. **kwargs: Unpack[TransformersKwargs],
  917. ) -> tuple | Glm4vCausalLMOutputWithPast:
  918. r"""
  919. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  920. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  921. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  922. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  923. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
  924. The temporal, height and width of feature shape of each image in LLM.
  925. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
  926. The temporal, height and width of feature shape of each video in LLM.
  927. Example:
  928. ```python
  929. >>> from PIL import Image
  930. >>> import httpx
  931. >>> from io import BytesIO
  932. >>> from transformers import AutoProcessor, Glm4vForConditionalGeneration
  933. >>> model = Glm4vForConditionalGeneration.from_pretrained("zai-org/GLM-4.1V-9B-Thinking")
  934. >>> processor = AutoProcessor.from_pretrained("zai-org/GLM-4.1V-9B-Thinking")
  935. >>> messages = [
  936. {
  937. "role": "user",
  938. "content": [
  939. {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
  940. {"type": "text", "text": "What is shown in this image?"},
  941. ],
  942. },
  943. ]
  944. >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
  945. >>> with httpx.stream("GET", url) as response:
  946. ... image = Image.open(BytesIO(response.read()))
  947. >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
  948. >>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos])
  949. >>> # Generate
  950. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  951. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  952. "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
  953. ```"""
  954. outputs = self.model(
  955. input_ids=input_ids,
  956. pixel_values=pixel_values,
  957. pixel_values_videos=pixel_values_videos,
  958. image_grid_thw=image_grid_thw,
  959. video_grid_thw=video_grid_thw,
  960. mm_token_type_ids=mm_token_type_ids,
  961. position_ids=position_ids,
  962. attention_mask=attention_mask,
  963. past_key_values=past_key_values,
  964. inputs_embeds=inputs_embeds,
  965. **kwargs,
  966. )
  967. hidden_states = outputs[0]
  968. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  969. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  970. logits = self.lm_head(hidden_states[:, slice_indices, :])
  971. loss = None
  972. if labels is not None:
  973. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
  974. return Glm4vCausalLMOutputWithPast(
  975. loss=loss,
  976. logits=logits,
  977. past_key_values=outputs.past_key_values,
  978. hidden_states=outputs.hidden_states,
  979. attentions=outputs.attentions,
  980. rope_deltas=outputs.rope_deltas,
  981. )
  982. def prepare_inputs_for_generation(
  983. self,
  984. input_ids,
  985. past_key_values=None,
  986. attention_mask=None,
  987. inputs_embeds=None,
  988. position_ids=None,
  989. use_cache=True,
  990. pixel_values=None,
  991. pixel_values_videos=None,
  992. image_grid_thw=None,
  993. video_grid_thw=None,
  994. is_first_iteration=False,
  995. **kwargs,
  996. ):
  997. # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
  998. model_inputs = super().prepare_inputs_for_generation(
  999. input_ids,
  1000. past_key_values=past_key_values,
  1001. attention_mask=attention_mask,
  1002. inputs_embeds=inputs_embeds,
  1003. position_ids=position_ids,
  1004. pixel_values=pixel_values,
  1005. pixel_values_videos=pixel_values_videos,
  1006. image_grid_thw=image_grid_thw,
  1007. video_grid_thw=video_grid_thw,
  1008. use_cache=use_cache,
  1009. is_first_iteration=is_first_iteration,
  1010. **kwargs,
  1011. )
  1012. if not is_first_iteration and use_cache:
  1013. model_inputs["pixel_values"] = None
  1014. model_inputs["pixel_values_videos"] = None
  1015. return model_inputs
  1016. def _get_image_nums_and_video_nums(
  1017. self,
  1018. input_ids: torch.LongTensor | None,
  1019. inputs_embeds: torch.Tensor | None = None,
  1020. ) -> tuple[torch.Tensor, torch.Tensor]:
  1021. """
  1022. Get the number of images and videos for each sample to calculate the separation length of the sample tensor.
  1023. These parameters are not passed through the processor to avoid unpredictable impacts from interface modifications.
  1024. Args:
  1025. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  1026. Indices of input sequence tokens in the vocabulary.
  1027. Returns:
  1028. image_nums (`torch.LongTensor` of shape `(batch_size, num_images_sample)`)
  1029. video_nums (`torch.LongTensor` of shape `(batch_size, num_videos_sample)`)
  1030. """
  1031. if inputs_embeds is not None:
  1032. is_image = (
  1033. inputs_embeds
  1034. == self.get_input_embeddings()(
  1035. torch.tensor(self.config.image_start_token_id, dtype=torch.long, device=inputs_embeds.device)
  1036. )
  1037. )[..., 0]
  1038. is_video_start = (
  1039. inputs_embeds
  1040. == self.get_input_embeddings()(
  1041. torch.tensor(self.config.video_start_token_id, dtype=torch.long, device=inputs_embeds.device)
  1042. )
  1043. )[..., 0]
  1044. is_video_end = (
  1045. inputs_embeds
  1046. == self.get_input_embeddings()(
  1047. torch.tensor(self.config.video_end_token_id, dtype=torch.long, device=inputs_embeds.device)
  1048. )
  1049. )[..., 0]
  1050. else:
  1051. is_image = input_ids == self.config.image_start_token_id
  1052. is_video_start = input_ids == self.config.video_start_token_id
  1053. is_video_end = input_ids == self.config.video_end_token_id
  1054. # Cumulative sum to track if we're inside a video span
  1055. # We'll assume well-formed video tags (i.e. matching starts and ends)
  1056. video_level = torch.cumsum(is_video_start.int() - is_video_end.int(), dim=1)
  1057. inside_video = video_level > 0 # shape (batch_size, seq_length)
  1058. # Mask out image tokens that are inside video spans
  1059. standalone_images = is_image & (~inside_video)
  1060. # Count per batch
  1061. image_counts = standalone_images.sum(dim=1)
  1062. video_counts = is_video_start.sum(dim=1)
  1063. return image_counts, video_counts
  1064. class Glm4vProcessorKwargs(Qwen2VLProcessorKwargs):
  1065. _defaults = {
  1066. "text_kwargs": {
  1067. "padding": False,
  1068. "return_token_type_ids": False,
  1069. "return_mm_token_type_ids": True,
  1070. },
  1071. "videos_kwargs": {"return_metadata": True},
  1072. }
  1073. class Glm4vProcessor(Qwen2VLProcessor):
  1074. def __init__(self, image_processor=None, tokenizer=None, video_processor=None, chat_template=None, **kwargs):
  1075. super().__init__(image_processor, tokenizer, video_processor, chat_template=chat_template)
  1076. self.image_token = "<|image|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token
  1077. self.video_token = "<|video|>" if not hasattr(tokenizer, "video_token") else tokenizer.video_token
  1078. self.video_start_id = tokenizer.convert_tokens_to_ids("<|begin_of_video|>")
  1079. self.video_end_id = tokenizer.convert_tokens_to_ids("<|end_of_video|>")
  1080. def __call__(
  1081. self,
  1082. images: ImageInput | None = None,
  1083. text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput] = None,
  1084. videos: VideoInput | None = None,
  1085. **kwargs: Unpack[Glm4vProcessorKwargs],
  1086. ) -> BatchFeature:
  1087. r"""
  1088. Returns:
  1089. [`BatchFeature`]: A [`BatchFeature`] with the following fields:
  1090. - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
  1091. - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
  1092. `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
  1093. `None`).
  1094. - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
  1095. - **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`.
  1096. - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`.
  1097. - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`.
  1098. """
  1099. output_kwargs = self._merge_kwargs(
  1100. Glm4vProcessorKwargs,
  1101. tokenizer_init_kwargs=self.tokenizer.init_kwargs,
  1102. **kwargs,
  1103. )
  1104. if images is not None:
  1105. image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
  1106. image_grid_thw = image_inputs["image_grid_thw"]
  1107. else:
  1108. image_inputs = {}
  1109. image_grid_thw = None
  1110. if videos is not None:
  1111. videos_inputs = self.video_processor(videos=videos, **output_kwargs["videos_kwargs"])
  1112. # If user has not requested video metadata, pop it
  1113. if not kwargs.get("return_metadata"):
  1114. video_metadata = videos_inputs.pop("video_metadata")
  1115. else:
  1116. video_metadata = videos_inputs["video_metadata"]
  1117. video_grid_thw = videos_inputs["video_grid_thw"]
  1118. else:
  1119. videos_inputs = {}
  1120. video_grid_thw = None
  1121. if not isinstance(text, list):
  1122. text = [text]
  1123. text = text.copy() # below lines change text in-place
  1124. if image_grid_thw is not None:
  1125. merge_length = self.image_processor.merge_size**2
  1126. index = 0
  1127. for i in range(len(text)):
  1128. while self.image_token in text[i]:
  1129. num_image_tokens = image_grid_thw[index].prod() // merge_length
  1130. text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1)
  1131. index += 1
  1132. text[i] = text[i].replace("<|placeholder|>", self.image_token)
  1133. if video_grid_thw is not None:
  1134. merge_length = self.video_processor.merge_size**2
  1135. video_index = 0
  1136. for i in range(len(text)):
  1137. while self.video_token in text[i]:
  1138. num_frames = video_grid_thw[video_index][0]
  1139. video_structure = ""
  1140. metadata = video_metadata[video_index]
  1141. if metadata.fps is None:
  1142. logger.warning_once(
  1143. "SmolVLM requires frame timestamps to construct prompts, but the `fps` of the input video could not be inferred. "
  1144. "Probably `video_metadata` was missing from inputs and you passed pre-sampled frames. "
  1145. "Defaulting to `fps=24`. Please provide `video_metadata` for more accurate results."
  1146. )
  1147. metadata.fps = 24 if metadata.fps is None else metadata.fps
  1148. timestamps = metadata.timestamps[::2] # mrope
  1149. unique_timestamps = []
  1150. for idx in range(0, len(timestamps)):
  1151. unique_timestamps.append(timestamps[idx])
  1152. selected_timestamps = unique_timestamps[:num_frames]
  1153. while len(selected_timestamps) < num_frames:
  1154. selected_timestamps.append(selected_timestamps[-1] if selected_timestamps else 0)
  1155. for frame_idx in range(num_frames):
  1156. timestamp_sec = selected_timestamps[frame_idx]
  1157. frame_structure = self.replace_frame_token_id(timestamp_sec)
  1158. video_structure += frame_structure
  1159. text[i] = text[i].replace(self.video_token, video_structure, 1)
  1160. num_image_tokens = (
  1161. video_grid_thw[video_index].prod() // merge_length // video_grid_thw[video_index][0]
  1162. )
  1163. for frame_idx in range(num_frames):
  1164. if self.image_token in text[i]:
  1165. text[i] = text[i].replace(self.image_token, "<|placeholder|>" * num_image_tokens, 1)
  1166. video_index += 1
  1167. text[i] = text[i].replace("<|placeholder|>", self.image_token)
  1168. return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", None)
  1169. return_mm_token_type_ids = output_kwargs["text_kwargs"].pop("return_mm_token_type_ids", False)
  1170. text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
  1171. self._check_special_mm_tokens(text, text_inputs, modalities=["image", "video"])
  1172. if return_mm_token_type_ids:
  1173. text_inputs["mm_token_type_ids"] = self.create_mm_token_type_ids(text_inputs["input_ids"])
  1174. return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}, tensor_type=return_tensors)
  1175. def create_mm_token_type_ids(self, input_ids: list) -> list[list[int]]:
  1176. # We have to iterate for each list separately because inputs
  1177. # might be non-padded lists and we can't cast numpy on that!
  1178. # Then cast numpy as each input for faster indexing
  1179. mm_token_type_ids = []
  1180. for input in input_ids:
  1181. array_ids = np.array(input)
  1182. mm_token_types = np.zeros_like(input)
  1183. # Replace 0 -> 2 only inside video segments because GLM4v
  1184. # uses the same special token to denote images and video
  1185. # Otherwise replace 0 -> 1 for image modality
  1186. starts = np.cumsum(array_ids == self.video_start_id, axis=0)
  1187. ends = np.cumsum(array_ids == self.video_end_id, axis=0)
  1188. is_video_modality = starts > ends
  1189. mm_token_types[(array_ids == self.image_token_id) & is_video_modality] = 2
  1190. mm_token_types[(array_ids == self.image_token_id) & (~is_video_modality)] = 1
  1191. mm_token_type_ids.append(mm_token_types.tolist())
  1192. return mm_token_type_ids
  1193. def replace_frame_token_id(self, timestamp_sec):
  1194. return f"<|begin_of_image|>{self.image_token}<|end_of_image|>{int(timestamp_sec)}"
  1195. __all__ = [
  1196. "Glm4vConfig",
  1197. "Glm4vTextConfig",
  1198. "Glm4vVisionConfig",
  1199. "Glm4vForConditionalGeneration",
  1200. "Glm4vModel",
  1201. "Glm4vPreTrainedModel",
  1202. "Glm4vProcessor",
  1203. "Glm4vTextModel",
  1204. "Glm4vVisionModel",
  1205. ]