modeling_glm46v.py 42 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/glm46v/modular_glm46v.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_glm46v.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2025 the HuggingFace Team. All rights reserved.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. import itertools
  21. from dataclasses import dataclass
  22. from typing import Any
  23. import torch
  24. import torch.nn as nn
  25. from ...cache_utils import Cache
  26. from ...generation import GenerationMixin
  27. from ...modeling_outputs import BaseModelOutputWithPooling, ModelOutput
  28. from ...modeling_utils import PreTrainedModel
  29. from ...processing_utils import Unpack
  30. from ...utils import (
  31. TransformersKwargs,
  32. auto_docstring,
  33. can_return_tuple,
  34. torch_compilable_check,
  35. )
  36. from ..auto import AutoModel
  37. from .configuration_glm46v import Glm46VConfig
  38. @auto_docstring
  39. class Glm46VPreTrainedModel(PreTrainedModel):
  40. config: Glm46VConfig
  41. base_model_prefix = "model"
  42. input_modalities = ("image", "video", "text")
  43. supports_gradient_checkpointing = True
  44. _no_split_modules = None
  45. _skip_keys_device_placement = "past_key_values"
  46. _supports_flash_attn = True
  47. _supports_sdpa = True
  48. _can_compile_fullgraph = True
  49. _supports_attention_backend = True
  50. _can_record_outputs = None
  51. @dataclass
  52. @auto_docstring(
  53. custom_intro="""
  54. Base class for Llava outputs, with hidden states and attentions.
  55. """
  56. )
  57. class Glm46VModelOutputWithPast(ModelOutput):
  58. r"""
  59. past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  60. It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
  61. Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
  62. `past_key_values` input) to speed up sequential decoding.
  63. rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
  64. The rope index difference between sequence length and multimodal rope.
  65. """
  66. last_hidden_state: torch.FloatTensor | None = None
  67. past_key_values: Cache | None = None
  68. hidden_states: tuple[torch.FloatTensor] | None = None
  69. attentions: tuple[torch.FloatTensor] | None = None
  70. rope_deltas: torch.LongTensor | None = None
  71. @auto_docstring
  72. class Glm46VModel(Glm46VPreTrainedModel):
  73. base_model_prefix = "model"
  74. # Reference: fix gemma3 grad acc #37208
  75. accepts_loss_kwargs = False
  76. _no_split_modules = None
  77. def __init__(self, config):
  78. super().__init__(config)
  79. self.visual = AutoModel.from_config(config.vision_config)
  80. self.language_model = AutoModel.from_config(config.text_config)
  81. self.rope_deltas = None # cache rope_deltas here
  82. # Initialize weights and apply final processing
  83. self.post_init()
  84. def get_input_embeddings(self):
  85. return self.language_model.get_input_embeddings()
  86. def set_input_embeddings(self, value):
  87. self.language_model.set_input_embeddings(value)
  88. def get_vision_position_ids(
  89. self,
  90. start_position: int,
  91. grid_thw: list[int, int, int] | torch.Tensor,
  92. temp_merge_size: int = 1,
  93. spatial_merge_size: int = 1,
  94. time_interval: int = 1,
  95. device: str | torch.device | None = None,
  96. ):
  97. """
  98. Compute 3D positional indices for vision tokens derived from a single image or video input.
  99. The positions are generated from the input grid defined by temporal (T), height (H), and
  100. width (W) dimensions. Temporal and spatial dimensions can be downscaled according to the
  101. merge sizes used in the vision backbone. The resulting positions are offset by `start_position`.
  102. Args:
  103. start_position (`int`):
  104. Offset added to all computed positional indices.
  105. grid_thw (`Sequence[int]` or `torch.Tensor` of shape `(3,)`):
  106. The (T, H, W) grid representing the feature layout of the current image or video after patch embedding.
  107. temp_merge_size (`int`, *optional*):
  108. Factor by which the temporal dimension is reduced in the backbone. The temporal grid size is divided
  109. by this value. Defaults to 1.
  110. spatial_merge_size (`int`, *optional*):
  111. Factor by which the spatial dimensions (H and W) are reduced in the backbone. Both H and W are divided
  112. by this value. Defaults to 1.
  113. time_interval (`int`, *optional*):
  114. Spacing factor applied between consecutive temporal position indices.Defaults to 1.
  115. device (`str` or `torch.device`, *optional*):
  116. Device on which the resulting tensor is allocated. If `None`, uses the current default device.
  117. Returns:
  118. torch.LongTensor of shape (3, sequence_length):
  119. Positional indices for temporal, height, and width dimensions,
  120. flattened into sequence form and offset by `start_position`.
  121. """
  122. llm_grid_t, llm_grid_h, llm_grid_w = (
  123. grid_thw[0].item() // temp_merge_size,
  124. grid_thw[1].item() // spatial_merge_size,
  125. grid_thw[2].item() // spatial_merge_size,
  126. )
  127. image_seq_length = llm_grid_h * llm_grid_w * llm_grid_t
  128. position_width = torch.arange(start_position, start_position + llm_grid_w, device=device).repeat(
  129. llm_grid_h * llm_grid_t
  130. )
  131. position_height = torch.arange(start_position, start_position + llm_grid_h, device=device).repeat_interleave(
  132. llm_grid_w * llm_grid_t
  133. )
  134. position_temporal = torch.full((image_seq_length,), start_position, device=device, dtype=torch.long)
  135. position_temporal = position_temporal * time_interval
  136. vision_position_ids = torch.stack([position_temporal, position_height, position_width], dim=0)
  137. return vision_position_ids
  138. def get_rope_index(
  139. self,
  140. input_ids: torch.LongTensor,
  141. mm_token_type_ids: torch.IntTensor,
  142. image_grid_thw: torch.LongTensor | None = None,
  143. video_grid_thw: torch.LongTensor | None = None,
  144. attention_mask: torch.Tensor | None = None,
  145. **kwargs,
  146. ) -> tuple[torch.Tensor, torch.Tensor]:
  147. """
  148. Calculate the 3D rope index based on image and video's sizes. The utility expects a `vision + text`
  149. sequence and will error out otherwise. For pure text sequence, please rely on model's auto-inferred
  150. position ids. In a mixed vision + text sequence, vision tokens use 3D RoPE (temporal, height, width)
  151. while text tokens use standard 1D RoPE.
  152. Example:
  153. Temporal patches: 3; Height patches: 2; Width patches: 2
  154. Each vision input results in (temporal x height × width) positions. Here: 3 x 2 × 2 = 12 positions total.
  155. Temporal position IDs are spaced by:
  156. `interval = tokens_per_second * temporal_patch_size / fps`
  157. If fps = 1; tokens_per_second = 25; temporal_patch_size = 2, temporal IDs increase by 50 for each temporal patch:
  158. `[0, 0, 0, 0, 50, 50, 50, 50, 100, 100, 100, 100]`
  159. Height IDs repeat per row: `[0, 0, 1, 1, ...]`
  160. Width IDs alternate per column: `[0, 1, 0, 1, ...]`
  161. Text tokens follow standard 1D RoPE and the position IDs grow consequently with a step of `1`
  162. Args:
  163. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  164. Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
  165. it.
  166. mm_token_type_ids (`torch.IntTensor` of shape `(batch_size, sequence_length)`):
  167. Token type ids matching each modality to a different value in the input sequence, i.e. text (0), image (1), video (2).
  168. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
  169. The temporal, height and width of feature shape of each image in LLM.
  170. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
  171. The temporal, height and width of feature shape of each video in LLM.
  172. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  173. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  174. - 1 for tokens that are **not masked**,
  175. - 0 for tokens that are **masked**.
  176. Returns:
  177. position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`)
  178. mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`)
  179. """
  180. spatial_merge_size = self.config.vision_config.spatial_merge_size
  181. mrope_position_deltas = []
  182. position_ids = torch.zeros(
  183. 3,
  184. input_ids.shape[0],
  185. input_ids.shape[1],
  186. dtype=input_ids.dtype,
  187. device=input_ids.device,
  188. )
  189. grid_iters = {
  190. 1: iter(image_grid_thw) if image_grid_thw is not None else None,
  191. 2: iter(video_grid_thw) if video_grid_thw is not None else None,
  192. }
  193. for batch_idx, current_input_ids in enumerate(input_ids):
  194. input_token_type = mm_token_type_ids[batch_idx]
  195. if attention_mask is not None:
  196. current_input_ids = current_input_ids[attention_mask[batch_idx].bool()]
  197. input_token_type = input_token_type[attention_mask[batch_idx].bool()]
  198. input_type_group = []
  199. for key, group in itertools.groupby(enumerate(input_token_type.tolist()), lambda x: x[1]):
  200. group = list(group)
  201. start_index = group[0][0]
  202. end_index = group[-1][0] + 1
  203. input_type_group.append((key, start_index, end_index))
  204. current_pos = 0
  205. video_group_index = 0
  206. llm_pos_ids_list = []
  207. for modality_type, start_idx, end_idx in input_type_group:
  208. # text == 0
  209. if modality_type == 0:
  210. text_len = end_idx - start_idx
  211. llm_pos_ids_list.append(
  212. torch.arange(text_len, device=input_ids.device).view(1, -1).expand(3, -1) + current_pos
  213. )
  214. current_pos += text_len
  215. # image == 1, video == 2
  216. else:
  217. # GLM46V splits video into segments per frame but there's only one `grid_thw`
  218. # per whole video. We can't exhaus the iterator and have to re-use the grid
  219. # while processing the same video!
  220. if modality_type == 2:
  221. if video_group_index == 0:
  222. grid_thw = next(grid_iters[modality_type])
  223. video_group_index += 1
  224. video_group_index = 0 if video_group_index >= grid_thw[0] else video_group_index
  225. else:
  226. grid_thw = next(grid_iters[modality_type])
  227. # Videos are processed per frame separately, each temporal grid is always `1`
  228. temp_merge_size = grid_thw[0]
  229. vision_position_ids = self.get_vision_position_ids(
  230. current_pos, grid_thw, temp_merge_size, spatial_merge_size, device=input_ids.device
  231. )
  232. llm_pos_ids_list.append(vision_position_ids)
  233. current_pos += max(grid_thw[1], grid_thw[2]) // spatial_merge_size
  234. llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
  235. if attention_mask is not None:
  236. position_ids[:, batch_idx, attention_mask[batch_idx].bool()] = llm_positions.to(position_ids.device)
  237. else:
  238. position_ids[:, batch_idx] = llm_positions.to(position_ids.device)
  239. mrope_position_deltas.append(llm_positions.max() + 1 - len(current_input_ids))
  240. mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
  241. return position_ids, mrope_position_deltas
  242. @can_return_tuple
  243. @auto_docstring
  244. def get_video_features(
  245. self,
  246. pixel_values_videos: torch.FloatTensor,
  247. video_grid_thw: torch.LongTensor | None = None,
  248. **kwargs: Unpack[TransformersKwargs],
  249. ) -> tuple | BaseModelOutputWithPooling:
  250. r"""
  251. pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
  252. The tensors corresponding to the input videos.
  253. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
  254. The temporal, height and width of feature shape of each video in LLM.
  255. """
  256. pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
  257. # reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames
  258. temp_frames_hw = []
  259. video_grid_thw_list = video_grid_thw.tolist()
  260. for t, h, w in video_grid_thw_list:
  261. repeated_row = torch.tensor([1, h, w]).unsqueeze(0).repeat(t, 1)
  262. temp_frames_hw.append(repeated_row)
  263. flattened_video_grid_thw = torch.cat(temp_frames_hw, dim=0)
  264. vision_outputs = self.visual(
  265. pixel_values_videos, grid_thw=flattened_video_grid_thw, return_dict=True, **kwargs
  266. )
  267. split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
  268. video_embeds = torch.split(vision_outputs.pooler_output, split_sizes)
  269. vision_outputs.pooler_output = video_embeds
  270. return vision_outputs
  271. @can_return_tuple
  272. @auto_docstring
  273. def get_image_features(
  274. self,
  275. pixel_values: torch.FloatTensor,
  276. image_grid_thw: torch.LongTensor | None = None,
  277. **kwargs: Unpack[TransformersKwargs],
  278. ) -> tuple | BaseModelOutputWithPooling:
  279. r"""
  280. pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
  281. The tensors corresponding to the input images.
  282. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
  283. The temporal, height and width of feature shape of each image in LLM.
  284. """
  285. pixel_values = pixel_values.type(self.visual.dtype)
  286. vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, **kwargs)
  287. split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
  288. image_embeds = torch.split(vision_outputs.pooler_output, split_sizes)
  289. vision_outputs.pooler_output = image_embeds
  290. return vision_outputs
  291. def get_placeholder_mask(
  292. self,
  293. input_ids: torch.LongTensor,
  294. inputs_embeds: torch.FloatTensor,
  295. image_features: torch.FloatTensor | None = None,
  296. video_features: torch.FloatTensor | None = None,
  297. ):
  298. """
  299. Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
  300. equal to the length of multimodal features. If the lengths are different, an error is raised.
  301. """
  302. if input_ids is None:
  303. special_image_mask = inputs_embeds == self.get_input_embeddings()(
  304. torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
  305. )
  306. special_image_mask = special_image_mask.all(-1)
  307. special_video_mask = inputs_embeds == self.get_input_embeddings()(
  308. torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
  309. )
  310. special_video_mask = special_video_mask.all(-1)
  311. else:
  312. # GLM-4.1V and GLM-4.5V special_video_mask is special_image_mask
  313. special_image_mask = input_ids == self.config.image_token_id
  314. special_video_mask = input_ids == self.config.image_token_id
  315. n_image_tokens = special_image_mask.sum()
  316. special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
  317. if image_features is not None:
  318. torch_compilable_check(
  319. inputs_embeds[special_image_mask].numel() == image_features.numel(),
  320. f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {image_features.shape[0]}",
  321. )
  322. n_video_tokens = special_video_mask.sum()
  323. special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
  324. if video_features is not None:
  325. torch_compilable_check(
  326. inputs_embeds[special_video_mask].numel() == video_features.numel(),
  327. f"Video features and video tokens do not match, tokens: {n_video_tokens}, features: {video_features.shape[0]}",
  328. )
  329. return special_image_mask, special_video_mask
  330. def compute_3d_position_ids(
  331. self,
  332. input_ids: torch.Tensor | None,
  333. inputs_embeds: torch.Tensor | None,
  334. image_grid_thw: torch.Tensor | None = None,
  335. video_grid_thw: torch.Tensor | None = None,
  336. attention_mask: torch.Tensor | None = None,
  337. past_key_values: torch.Tensor | None = None,
  338. mm_token_type_ids: torch.IntTensor | None = None,
  339. ) -> torch.Tensor | None:
  340. past_key_values_length = 0 if past_key_values is None else past_key_values.get_seq_length()
  341. has_multimodal = image_grid_thw is not None or video_grid_thw is not None
  342. if has_multimodal and mm_token_type_ids is None and input_ids is not None:
  343. raise ValueError(
  344. "Multimodal data was passed (via `image_grid_thw` or `video_grid_thw`) but `mm_token_type_ids` is "
  345. "missing. Please pass `mm_token_type_ids` to the model so that multimodal RoPE (M-RoPE) can be "
  346. "computed correctly. `mm_token_type_ids` is returned by the processor alongside `input_ids`."
  347. )
  348. can_compute_mrope = input_ids is not None and mm_token_type_ids is not None and has_multimodal
  349. if can_compute_mrope and (self.rope_deltas is None or past_key_values_length == 0):
  350. position_ids, rope_deltas = self.get_rope_index(
  351. input_ids,
  352. image_grid_thw=image_grid_thw,
  353. video_grid_thw=video_grid_thw,
  354. attention_mask=attention_mask,
  355. mm_token_type_ids=mm_token_type_ids,
  356. )
  357. self.rope_deltas = rope_deltas
  358. # Use pre-calculated rope-deltas to infer correct 3D position ids during incremental
  359. # generation (past_key_values_length > 0) or when only inputs_embeds is provided (no input_ids
  360. # to recompute from). Skip when input_ids is provided without past_key_values to avoid shape
  361. # mismatches from stale rope_deltas (e.g., training forward pass after generation).
  362. elif self.rope_deltas is not None and (past_key_values_length > 0 or input_ids is None):
  363. batch_size, seq_length, _ = inputs_embeds.shape
  364. if attention_mask is not None:
  365. position_ids = attention_mask.long().cumsum(-1) - 1
  366. position_ids = position_ids.masked_fill(attention_mask == 0, 0)
  367. position_ids = position_ids.view(1, batch_size, -1).repeat(3, 1, 1).to(inputs_embeds.device)
  368. else:
  369. position_ids = torch.arange(past_key_values_length, past_key_values_length + seq_length)
  370. position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1).to(inputs_embeds.device)
  371. delta = self.rope_deltas.repeat_interleave(batch_size // self.rope_deltas.shape[0], dim=0)
  372. position_ids = position_ids + delta.to(device=inputs_embeds.device)
  373. else:
  374. # Can't build correct 3D positions. Let the model infer it
  375. position_ids = None
  376. return position_ids
  377. @auto_docstring
  378. @can_return_tuple
  379. def forward(
  380. self,
  381. input_ids: torch.LongTensor | None = None,
  382. attention_mask: torch.Tensor | None = None,
  383. position_ids: torch.LongTensor | None = None,
  384. past_key_values: Cache | None = None,
  385. inputs_embeds: torch.FloatTensor | None = None,
  386. pixel_values: torch.Tensor | None = None,
  387. pixel_values_videos: torch.FloatTensor | None = None,
  388. image_grid_thw: torch.LongTensor | None = None,
  389. video_grid_thw: torch.LongTensor | None = None,
  390. rope_deltas: torch.LongTensor | None = None,
  391. mm_token_type_ids: torch.IntTensor | None = None,
  392. **kwargs: Unpack[TransformersKwargs],
  393. ) -> tuple | Glm46VModelOutputWithPast:
  394. r"""
  395. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
  396. The temporal, height and width of feature shape of each image in LLM.
  397. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
  398. The temporal, height and width of feature shape of each video in LLM.
  399. rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
  400. The rope index difference between sequence length and multimodal rope.
  401. """
  402. if (input_ids is None) ^ (inputs_embeds is not None):
  403. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  404. if inputs_embeds is None:
  405. inputs_embeds = self.get_input_embeddings()(input_ids)
  406. if pixel_values is not None:
  407. image_embeds = self.get_image_features(pixel_values, image_grid_thw, return_dict=True).pooler_output
  408. image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
  409. image_mask, _ = self.get_placeholder_mask(input_ids, inputs_embeds, image_features=image_embeds)
  410. inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
  411. if pixel_values_videos is not None:
  412. video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw, return_dict=True).pooler_output
  413. video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
  414. _, video_mask = self.get_placeholder_mask(input_ids, inputs_embeds, video_features=video_embeds)
  415. inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
  416. if position_ids is None:
  417. position_ids = self.compute_3d_position_ids(
  418. input_ids=input_ids,
  419. image_grid_thw=image_grid_thw,
  420. video_grid_thw=video_grid_thw,
  421. inputs_embeds=inputs_embeds,
  422. attention_mask=attention_mask,
  423. past_key_values=past_key_values,
  424. mm_token_type_ids=mm_token_type_ids,
  425. )
  426. outputs = self.language_model(
  427. input_ids=None,
  428. position_ids=position_ids,
  429. attention_mask=attention_mask,
  430. past_key_values=past_key_values,
  431. inputs_embeds=inputs_embeds,
  432. **kwargs,
  433. )
  434. return Glm46VModelOutputWithPast(
  435. **outputs,
  436. rope_deltas=self.rope_deltas,
  437. )
  438. @dataclass
  439. @auto_docstring(
  440. custom_intro="""
  441. Base class for Glm46V causal language model (or autoregressive) outputs.
  442. """
  443. )
  444. class Glm46VCausalLMOutputWithPast(ModelOutput):
  445. r"""
  446. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
  447. Language modeling loss (for next-token prediction).
  448. logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
  449. Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
  450. past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  451. It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
  452. Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
  453. `past_key_values` input) to speed up sequential decoding.
  454. rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
  455. The rope index difference between sequence length and multimodal rope.
  456. """
  457. loss: torch.FloatTensor | None = None
  458. logits: torch.FloatTensor | None = None
  459. past_key_values: Cache | None = None
  460. hidden_states: tuple[torch.FloatTensor] | None = None
  461. attentions: tuple[torch.FloatTensor] | None = None
  462. rope_deltas: torch.LongTensor | None = None
  463. class Glm46VForConditionalGeneration(Glm46VPreTrainedModel, GenerationMixin):
  464. _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
  465. # Reference: fix gemma3 grad acc #37208
  466. accepts_loss_kwargs = False
  467. def __init__(self, config):
  468. super().__init__(config)
  469. self.model = Glm46VModel(config)
  470. self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
  471. self.post_init()
  472. def get_input_embeddings(self):
  473. return self.model.get_input_embeddings()
  474. def set_input_embeddings(self, value):
  475. self.model.set_input_embeddings(value)
  476. @auto_docstring
  477. def get_video_features(
  478. self,
  479. pixel_values_videos: torch.FloatTensor,
  480. video_grid_thw: torch.LongTensor | None = None,
  481. **kwargs: Unpack[TransformersKwargs],
  482. ) -> tuple | BaseModelOutputWithPooling:
  483. r"""
  484. pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
  485. The tensors corresponding to the input videos.
  486. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
  487. The temporal, height and width of feature shape of each video in LLM.
  488. """
  489. return self.model.get_video_features(
  490. pixel_values_videos=pixel_values_videos, video_grid_thw=video_grid_thw, **kwargs
  491. )
  492. @auto_docstring
  493. def get_image_features(
  494. self,
  495. pixel_values: torch.FloatTensor,
  496. image_grid_thw: torch.LongTensor | None = None,
  497. **kwargs: Unpack[TransformersKwargs],
  498. ) -> tuple | BaseModelOutputWithPooling:
  499. r"""
  500. pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
  501. The tensors corresponding to the input images.
  502. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
  503. The temporal, height and width of feature shape of each image in LLM.
  504. """
  505. return self.model.get_image_features(pixel_values=pixel_values, image_grid_thw=image_grid_thw, **kwargs)
  506. @can_return_tuple
  507. @auto_docstring
  508. def forward(
  509. self,
  510. input_ids: torch.LongTensor | None = None,
  511. attention_mask: torch.Tensor | None = None,
  512. position_ids: torch.LongTensor | None = None,
  513. past_key_values: Cache | None = None,
  514. inputs_embeds: torch.FloatTensor | None = None,
  515. labels: torch.LongTensor | None = None,
  516. pixel_values: torch.Tensor | None = None,
  517. pixel_values_videos: torch.FloatTensor | None = None,
  518. image_grid_thw: torch.LongTensor | None = None,
  519. video_grid_thw: torch.LongTensor | None = None,
  520. mm_token_type_ids: torch.IntTensor | None = None,
  521. logits_to_keep: int | torch.Tensor = 0,
  522. **kwargs: Unpack[TransformersKwargs],
  523. ) -> tuple | Glm46VCausalLMOutputWithPast:
  524. r"""
  525. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  526. Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
  527. config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
  528. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
  529. image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
  530. The temporal, height and width of feature shape of each image in LLM.
  531. video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
  532. The temporal, height and width of feature shape of each video in LLM.
  533. Example:
  534. ```python
  535. >>> from PIL import Image
  536. >>> import httpx
  537. >>> from io import BytesIO
  538. >>> from transformers import AutoProcessor, Glm46VForConditionalGeneration
  539. >>> model = Glm46VForConditionalGeneration.from_pretrained("zai-org/GLM-4.1V-9B-Thinking")
  540. >>> processor = AutoProcessor.from_pretrained("zai-org/GLM-4.1V-9B-Thinking")
  541. >>> messages = [
  542. {
  543. "role": "user",
  544. "content": [
  545. {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
  546. {"type": "text", "text": "What is shown in this image?"},
  547. ],
  548. },
  549. ]
  550. >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
  551. >>> with httpx.stream("GET", url) as response:
  552. ... image = Image.open(BytesIO(response.read()))
  553. >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
  554. >>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos])
  555. >>> # Generate
  556. >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
  557. >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
  558. "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
  559. ```"""
  560. outputs = self.model(
  561. input_ids=input_ids,
  562. pixel_values=pixel_values,
  563. pixel_values_videos=pixel_values_videos,
  564. image_grid_thw=image_grid_thw,
  565. video_grid_thw=video_grid_thw,
  566. mm_token_type_ids=mm_token_type_ids,
  567. position_ids=position_ids,
  568. attention_mask=attention_mask,
  569. past_key_values=past_key_values,
  570. inputs_embeds=inputs_embeds,
  571. **kwargs,
  572. )
  573. hidden_states = outputs[0]
  574. # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
  575. slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
  576. logits = self.lm_head(hidden_states[:, slice_indices, :])
  577. loss = None
  578. if labels is not None:
  579. loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
  580. return Glm46VCausalLMOutputWithPast(
  581. loss=loss,
  582. logits=logits,
  583. past_key_values=outputs.past_key_values,
  584. hidden_states=outputs.hidden_states,
  585. attentions=outputs.attentions,
  586. rope_deltas=outputs.rope_deltas,
  587. )
  def prepare_inputs_for_generation(
      self,
      input_ids,
      past_key_values=None,
      attention_mask=None,
      inputs_embeds=None,
      position_ids=None,
      use_cache=True,
      pixel_values=None,
      pixel_values_videos=None,
      image_grid_thw=None,
      video_grid_thw=None,
      is_first_iteration=False,
      **kwargs,
  ):
      """
      Build the model inputs for one generation step.

      Overwritten -- delegates to the parent implementation, then, when `use_cache` is enabled, clears the raw
      pixel inputs (`pixel_values`, `pixel_values_videos`) on every step except the first one. The grid sizes
      (`image_grid_thw`, `video_grid_thw`) are left exactly as the parent returned them.

      Returns:
          The model inputs produced by the parent `prepare_inputs_for_generation`, with the two pixel entries
          set to `None` on non-first cached steps.
      """
      model_inputs = super().prepare_inputs_for_generation(
          input_ids,
          past_key_values=past_key_values,
          attention_mask=attention_mask,
          inputs_embeds=inputs_embeds,
          position_ids=position_ids,
          use_cache=use_cache,
          is_first_iteration=is_first_iteration,
          pixel_values=pixel_values,
          pixel_values_videos=pixel_values_videos,
          image_grid_thw=image_grid_thw,
          video_grid_thw=video_grid_thw,
          **kwargs,
      )

      # NOTE(review): presumably later cached steps reuse the vision features computed on the first step,
      # which is why the raw pixels are not forwarded again.
      if use_cache and not is_first_iteration:
          for pixel_key in ("pixel_values", "pixel_values_videos"):
              model_inputs[pixel_key] = None

      return model_inputs
  def _prepare_position_ids_for_generation(self, inputs_tensor, model_kwargs):
      """
      Build the initial position ids for generation.

      Overwritten -- requires 3D position ids. In the regular path, the parent's text positions are stacked
      on top of three rows of vision positions into a `[4, bs, seq-len]` tensor, and `self.model.rope_deltas`
      is (re)assigned as a side effect.

      Early exit: when `past_key_values` is a non-empty cache and `self.model.rope_deltas` is already set,
      the text positions (with a leading axis added) are shifted by the stored `rope_deltas` and returned
      directly, without recomputing vision positions or touching `rope_deltas`.
      """
      text_positions = super()._prepare_position_ids_for_generation(inputs_tensor, model_kwargs)

      cache = model_kwargs.get("past_key_values")
      past_length = cache.get_seq_length() if cache is not None else 0
      if past_length != 0 and self.model.rope_deltas is not None:
          return text_positions[None, ...] + self.model.rope_deltas

      # Prefer non-empty `input_ids` from the kwargs over the tensor handed in by the caller.
      if "input_ids" in model_kwargs and model_kwargs["input_ids"].shape[1] > 0:
          inputs_tensor = model_kwargs["input_ids"]

      has_token_ids = inputs_tensor.ndim == 2 and inputs_tensor.dtype in [torch.int, torch.long]
      has_mm_types = model_kwargs.get("mm_token_type_ids") is not None
      has_grids = model_kwargs.get("image_grid_thw") is not None or model_kwargs.get("video_grid_thw") is not None

      if has_token_ids and has_mm_types and has_grids:
          # The ids are passed positionally; the "input_ids" entry is dropped from the forwarded kwargs
          # (presumably so they are not supplied twice).
          rope_kwargs = {name: value for name, value in model_kwargs.items() if name != "input_ids"}
          vision_positions, self.model.rope_deltas = self.model.get_rope_index(inputs_tensor, **rope_kwargs)
      else:
          # No multimodal layout to follow: all three vision rows reuse the text positions, with zero deltas.
          vision_positions = text_positions.unsqueeze(0).expand(3, -1, -1)
          self.model.rope_deltas = torch.zeros(
              inputs_tensor.shape[0], 1, dtype=torch.long, device=inputs_tensor.device
          )

      # Row 0 holds the text positions, rows 1-3 the vision positions: [4, bs, seq-len].
      return torch.cat([text_positions[None, ...], vision_positions], dim=0)
  653. def _get_image_nums_and_video_nums(
  654. self,
  655. input_ids: torch.LongTensor | None,
  656. inputs_embeds: torch.Tensor | None = None,
  657. ) -> tuple[torch.Tensor, torch.Tensor]:
  658. """
  659. Get the number of images and videos for each sample to calculate the separation length of the sample tensor.
  660. These parameters are not passed through the processor to avoid unpredictable impacts from interface modifications.
  661. Args:
  662. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  663. Indices of input sequence tokens in the vocabulary.
  664. Returns:
  665. image_nums (`torch.LongTensor` of shape `(batch_size, num_images_sample)`)
  666. video_nums (`torch.LongTensor` of shape `(batch_size, num_videos_sample)`)
  667. """
  668. if inputs_embeds is not None:
  669. is_image = (
  670. inputs_embeds
  671. == self.get_input_embeddings()(
  672. torch.tensor(self.config.image_start_token_id, dtype=torch.long, device=inputs_embeds.device)
  673. )
  674. )[..., 0]
  675. is_video_start = (
  676. inputs_embeds
  677. == self.get_input_embeddings()(
  678. torch.tensor(self.config.video_start_token_id, dtype=torch.long, device=inputs_embeds.device)
  679. )
  680. )[..., 0]
  681. is_video_end = (
  682. inputs_embeds
  683. == self.get_input_embeddings()(
  684. torch.tensor(self.config.video_end_token_id, dtype=torch.long, device=inputs_embeds.device)
  685. )
  686. )[..., 0]
  687. else:
  688. is_image = input_ids == self.config.image_start_token_id
  689. is_video_start = input_ids == self.config.video_start_token_id
  690. is_video_end = input_ids == self.config.video_end_token_id
  691. # Cumulative sum to track if we're inside a video span
  692. # We'll assume well-formed video tags (i.e. matching starts and ends)
  693. video_level = torch.cumsum(is_video_start.int() - is_video_end.int(), dim=1)
  694. inside_video = video_level > 0 # shape (batch_size, seq_length)
  695. # Mask out image tokens that are inside video spans
  696. standalone_images = is_image & (~inside_video)
  697. # Count per batch
  698. image_counts = standalone_images.sum(dim=1)
  699. video_counts = is_video_start.sum(dim=1)
  700. return image_counts, video_counts
  701. def _expand_inputs_for_generation(
  702. self,
  703. expand_size: int = 1,
  704. is_encoder_decoder: bool = False,
  705. input_ids: torch.LongTensor | None = None,
  706. **model_kwargs,
  707. ) -> tuple[torch.LongTensor, dict[str, Any]]:
  708. # Overwritten -- Support for expanding tensors without a batch size dimension
  709. # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw, second_per_grid_t
  710. # pixel_values.shape[0] is sum(seqlen_images for samples)
  711. # image_grid_thw.shape[0] is sum(num_images for samples)
  712. if expand_size == 1:
  713. return input_ids, model_kwargs
  714. visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw", "second_per_grid_ts"]
  715. def _expand_dict_for_generation_visual(dict_to_expand):
  716. image_grid_thw = model_kwargs.get("image_grid_thw", None)
  717. video_grid_thw = model_kwargs.get("video_grid_thw", None)
  718. image_nums, video_nums = self._get_image_nums_and_video_nums(
  719. input_ids, inputs_embeds=model_kwargs.get("inputs_embeds", None)
  720. )
  721. def _repeat_interleave_samples(x, lengths, repeat_times):
  722. samples = torch.split(x, lengths)
  723. repeat_args = [repeat_times] + [1] * (x.dim() - 1)
  724. result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0)
  725. return result
  726. for key in dict_to_expand:
  727. if key == "pixel_values":
  728. # split images into samples
  729. samples = torch.split(image_grid_thw, list(image_nums))
  730. # compute the sequence length of images for each sample
  731. lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
  732. dict_to_expand[key] = _repeat_interleave_samples(
  733. dict_to_expand[key], lengths=lengths, repeat_times=expand_size
  734. )
  735. elif key == "image_grid_thw":
  736. # get the num of images for each sample
  737. lengths = list(image_nums)
  738. dict_to_expand[key] = _repeat_interleave_samples(
  739. dict_to_expand[key], lengths=lengths, repeat_times=expand_size
  740. )
  741. elif key == "pixel_values_videos":
  742. samples = torch.split(video_grid_thw, list(video_nums))
  743. lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
  744. dict_to_expand[key] = _repeat_interleave_samples(
  745. dict_to_expand[key], lengths=lengths, repeat_times=expand_size
  746. )
  747. elif key == "video_grid_thw":
  748. lengths = list(video_nums)
  749. dict_to_expand[key] = _repeat_interleave_samples(
  750. dict_to_expand[key], lengths=lengths, repeat_times=expand_size
  751. )
  752. elif key == "second_per_grid_ts":
  753. dict_to_expand[key] = _repeat_interleave_samples(
  754. dict_to_expand[key], lengths=list(video_nums), repeat_times=expand_size
  755. )
  756. return dict_to_expand
  757. def _expand_dict_for_generation(dict_to_expand):
  758. for key in dict_to_expand:
  759. if key == "position_ids" and dict_to_expand[key].ndim == 3:
  760. dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=1)
  761. elif (
  762. dict_to_expand[key] is not None
  763. and isinstance(dict_to_expand[key], torch.Tensor)
  764. and key not in visual_keys
  765. ):
  766. dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
  767. return dict_to_expand
  768. model_kwargs = _expand_dict_for_generation_visual(model_kwargs)
  769. if input_ids is not None:
  770. input_ids = input_ids.repeat_interleave(expand_size, dim=0)
  771. model_kwargs = _expand_dict_for_generation(model_kwargs)
  772. if is_encoder_decoder:
  773. if model_kwargs.get("encoder_outputs") is None:
  774. raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
  775. model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])
  776. return input_ids, model_kwargs
  # Public names exported by `from <this module> import *`.
  __all__ = [
      "Glm46VModel",
      "Glm46VPreTrainedModel",
      "Glm46VForConditionalGeneration",
  ]