attention_visualizer.py 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255
  1. # Copyright 2025 The HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import io
  15. import httpx
  16. from PIL import Image
  17. from ..masking_utils import create_causal_mask
  18. from ..models.auto.auto_factory import _get_model_class
  19. from ..models.auto.configuration_auto import AutoConfig
  20. from ..models.auto.modeling_auto import MODEL_FOR_PRETRAINING_MAPPING, MODEL_MAPPING
  21. from ..models.auto.processing_auto import PROCESSOR_MAPPING_NAMES, AutoProcessor
  22. from ..models.auto.tokenization_auto import AutoTokenizer
  23. from .import_utils import is_torch_available
  24. if is_torch_available():
  25. import torch
  26. import torch.nn as nn
  27. # Print the matrix with words as row labels
  28. GREEN = "\033[92m"
  29. YELLOW = "\033[93m"
  30. RESET = "\033[0m"
  31. BLACK_SQUARE = "■"
  32. WHITE_SQUARE = "⬚"
  33. def generate_attention_matrix_from_mask(
  34. words, mask, img_token="<img>", sliding_window=None, token_type_ids=None, image_seq_length=None
  35. ):
  36. """
  37. Generates an attention matrix from a given attention mask.
  38. Optionally applies a sliding window mask (e.g., for Gemma2/3) and
  39. marks regions where image tokens occur based on the specified `img_token`.
  40. """
  41. mask = mask.int()
  42. if mask.ndim == 3:
  43. mask = mask[0, :, :]
  44. if mask.ndim == 4:
  45. mask = mask[0, 0, :, :]
  46. n = len(words)
  47. max_word_length = max(len(repr(word)) for word in words)
  48. first_img_idx = 0
  49. output = []
  50. for i, k in enumerate(words):
  51. if k == img_token and not first_img_idx:
  52. first_img_idx = i
  53. mask[i, i] = 2 # Mark yellow regions
  54. if first_img_idx > 0 and (k != img_token or i == n - 1):
  55. if i == n - 1:
  56. i += 1
  57. mask[first_img_idx:i, first_img_idx:i] = 2 # Mark yellow regions
  58. first_img_idx = 0
  59. # Generate sliding window mask (size = 4), excluding img_token
  60. sliding_window_mask = None
  61. if sliding_window is not None:
  62. sliding_window_mask = [[1 if (0 <= i - j < sliding_window) else 0 for j in range(n)] for i in range(n)]
  63. row_dummy = " ".join(
  64. f"{YELLOW}{BLACK_SQUARE}{RESET}"
  65. if mask[0, j]
  66. else f"{GREEN}{BLACK_SQUARE}{RESET}"
  67. if j == 0
  68. else BLACK_SQUARE
  69. if mask[0, j]
  70. else WHITE_SQUARE
  71. for j in range(n)
  72. )
  73. if token_type_ids is not None:
  74. is_special = token_type_ids == 1
  75. token_type_buckets = torch.where(
  76. (token_type_ids.cumsum(-1) % 5 + is_special).bool(), token_type_ids.cumsum(-1), 0
  77. )
  78. boundaries = torch.arange(0, image_seq_length + 1, image_seq_length)
  79. token_type_buckets = torch.bucketize(token_type_buckets, boundaries=boundaries)
  80. # Print headers
  81. legend = f"{GREEN}{BLACK_SQUARE}{RESET}: i == j (diagonal) {YELLOW}{BLACK_SQUARE}{RESET}: token_type_ids"
  82. output.append(" " + legend)
  83. f_string = " " * (max_word_length + 5) + "Attention Matrix".ljust(len(row_dummy) // 2)
  84. if sliding_window is not None:
  85. f_string += "Sliding Window Mask"
  86. output.append(f_string)
  87. vertical_header = []
  88. for idx, word in enumerate(words):
  89. if mask[idx, idx] == 2:
  90. vertical_header.append([f"{YELLOW}{k}{RESET}" for k in list(str(idx).rjust(len(str(n))))])
  91. else:
  92. vertical_header.append(list(str(idx).rjust(len(str(n)))))
  93. vertical_header = list(map(list, zip(*vertical_header))) # Transpose
  94. for row in vertical_header:
  95. output.append(
  96. (max_word_length + 5) * " " + " ".join(row) + " | " + " ".join(row)
  97. if sliding_window is not None
  98. else ""
  99. )
  100. for i, word in enumerate(words):
  101. word_repr = repr(word).ljust(max_word_length)
  102. colored_word = f"{YELLOW}{word_repr}{RESET}" if img_token in word else word_repr
  103. row_display = " ".join(
  104. f"{YELLOW}{BLACK_SQUARE}{RESET}"
  105. if img_token in words[j] and mask[i, j] and img_token in word
  106. else f"{GREEN}{BLACK_SQUARE}{RESET}"
  107. if i == j
  108. else BLACK_SQUARE
  109. if mask[i, j]
  110. else WHITE_SQUARE
  111. for j in range(n)
  112. )
  113. sliding_window_row = ""
  114. if sliding_window is not None:
  115. sliding_window_row = " ".join(
  116. f"{YELLOW}{BLACK_SQUARE}{RESET}"
  117. if img_token in words[j] and img_token in word and token_type_buckets[0, i] == token_type_buckets[0, j]
  118. else f"{GREEN}{BLACK_SQUARE}{RESET}"
  119. if i == j
  120. else BLACK_SQUARE
  121. if sliding_window_mask[i][j]
  122. else WHITE_SQUARE
  123. for j in range(n)
  124. )
  125. output.append(f"{colored_word}: {str(i).rjust(2)} {row_display} | {sliding_window_row}")
  126. return "\n".join(output)
  127. class AttentionMaskVisualizer:
  128. def __init__(self, model_name: str):
  129. config = AutoConfig.from_pretrained(model_name)
  130. self.image_token = "<img>"
  131. if hasattr(config.get_text_config(), "sliding_window"):
  132. self.sliding_window = getattr(config.get_text_config(), "sliding_window", None)
  133. try:
  134. mapped_cls = _get_model_class(config, MODEL_MAPPING)
  135. except Exception:
  136. mapped_cls = _get_model_class(config, MODEL_FOR_PRETRAINING_MAPPING)
  137. if mapped_cls is None:
  138. raise ValueError(f"Model name {model_name} is not supported for attention visualization")
  139. self.mapped_cls = mapped_cls
  140. class _ModelWrapper(mapped_cls, nn.Module):
  141. def __init__(self, config, model_name):
  142. nn.Module.__init__(self)
  143. self.dummy_module = nn.Linear(1, 1)
  144. self.config = config
  145. self.model = _ModelWrapper(config, model_name)
  146. self.model.to(config.dtype)
  147. self.repo_id = model_name
  148. self.config = config
  149. def __call__(self, input_sentence: str, suffix=""):
  150. self.visualize_attention_mask(input_sentence, suffix=suffix)
  151. def visualize_attention_mask(self, input_sentence: str, suffix=""):
  152. model = self.model
  153. kwargs = {}
  154. image_seq_length = None
  155. if self.config.model_type in PROCESSOR_MAPPING_NAMES:
  156. img = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg?download=true"
  157. img = Image.open(io.BytesIO(httpx.get(img, follow_redirects=True).content))
  158. image_seq_length = 5
  159. processor = AutoProcessor.from_pretrained(self.repo_id, image_seq_length=image_seq_length)
  160. if hasattr(processor, "image_token"):
  161. image_token = processor.image_token
  162. else:
  163. image_token = processor.tokenizer.convert_ids_to_tokens([processor.image_token_id])[0]
  164. if image_token:
  165. input_sentence = input_sentence.replace("<img>", image_token)
  166. inputs = processor(images=img, text=input_sentence, suffix=suffix, return_tensors="pt")
  167. self.image_token = processor.tokenizer.convert_ids_to_tokens([processor.image_token_id])[0]
  168. attention_mask = inputs["attention_mask"]
  169. if "token_type_ids" in inputs: # TODO inspect signature of update causal mask
  170. kwargs["token_type_ids"] = inputs["token_type_ids"]
  171. tokens = processor.tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
  172. else:
  173. tokenizer = AutoTokenizer.from_pretrained(self.repo_id)
  174. if tokenizer is None:
  175. raise ValueError(f"Could not load tokenizer for {self.repo_id}")
  176. tokens = tokenizer.tokenize(input_sentence)
  177. attention_mask = tokenizer(input_sentence, return_tensors="pt")["attention_mask"]
  178. if attention_mask is None:
  179. raise ValueError(f"Model type {self.config.model_type} does not support attention visualization")
  180. model.config._attn_implementation = "eager"
  181. model.train()
  182. batch_size, seq_length = attention_mask.shape
  183. inputs_embeds = torch.zeros((batch_size, seq_length, model.config.hidden_size), dtype=self.model.dtype)
  184. causal_mask = create_causal_mask(
  185. config=model.config,
  186. inputs_embeds=inputs_embeds,
  187. attention_mask=attention_mask,
  188. past_key_values=None,
  189. )
  190. if causal_mask is None:
  191. # attention_mask must be a tensor here
  192. attention_mask = attention_mask.unsqueeze(1).unsqueeze(1).expand(batch_size, 1, seq_length, seq_length)
  193. elif isinstance(causal_mask, torch.Tensor):
  194. attention_mask = ~causal_mask.to(dtype=torch.bool)
  195. else:
  196. attention_mask = ~causal_mask
  197. top_bottom_border = "##" * (
  198. len(f"Attention visualization for {self.config.model_type} | {self.mapped_cls}") + 4
  199. ) # Box width adjusted to text length
  200. side_border = "##"
  201. print(f"\n{top_bottom_border}")
  202. print(
  203. "##"
  204. + f" Attention visualization for \033[1m{self.config.model_type}:{self.repo_id}\033[0m {self.mapped_cls.__name__}".center(
  205. len(top_bottom_border)
  206. )
  207. + " "
  208. + side_border,
  209. )
  210. print(f"{top_bottom_border}")
  211. f_string = generate_attention_matrix_from_mask(
  212. tokens,
  213. attention_mask,
  214. img_token=self.image_token,
  215. sliding_window=getattr(self.config, "sliding_window", None),
  216. token_type_ids=kwargs.get("token_type_ids"),
  217. image_seq_length=image_seq_length,
  218. )
  219. print(f_string)
  220. print(f"{top_bottom_border}")