output_capturing.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285
  1. # Copyright 2026 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. Contains the logic for automatic additional output capture with our forward decorators.
  16. This mostly describes the hooks used, and the logic that makes capture thread/context safe.
  17. """
  18. from __future__ import annotations
  19. import threading
  20. from contextvars import ContextVar
  21. from dataclasses import dataclass
  22. from functools import wraps
  23. from typing import TYPE_CHECKING
  24. from .import_utils import is_torchdynamo_compiling, requires
  25. if TYPE_CHECKING:
  26. from torch import nn
  27. from ..modeling_utils import PreTrainedModel
  28. _CAN_RECORD_REGISTRY = {}
  29. @dataclass
  30. @requires(backends=("torch",))
  31. class OutputRecorder:
  32. """
  33. Configuration for recording outputs from a model via hooks.
  34. Attributes:
  35. target_class (Type): The class (e.g., nn.Module) to which the hook will be attached.
  36. index (Optional[int]): If the output is a tuple/list, optionally record only at a specific index.
  37. layer_name (Optional[str]): Name of the submodule to target (if needed), e.g., "transformer.layer.3.attn".
  38. class_name (Optional[str]): Name of the class to which the hook will be attached. Could be the suffix of class name in some cases.
  39. """
  40. target_class: type[nn.Module]
  41. index: int = 0
  42. layer_name: str | None = None
  43. class_name: str | None = None
  44. class CompileableContextVar:
  45. """
  46. Convenience wrapper around a ContextVar for usage with `torch.compile`.
  47. This behaves exactly as a `ContextVar`, except when compilation is triggered in which case it behaves as a simple
  48. global variable. This is useful as `torch.compile` cannot trace the `get` method of `ContextVar`. This however means
  49. that the access to the underlying variable is not thread-safe when compilation is triggered.
  50. """
  51. def __init__(self, name):
  52. self.context_var = ContextVar(name, default=None)
  53. self.global_var = None
  54. self.compiling = False
  55. def get(self):
  56. # Set was called before and compilation was already detected
  57. if self.compiling:
  58. return self.global_var
  59. else:
  60. return self.context_var.get()
  61. def set(self, value):
  62. if is_torchdynamo_compiling():
  63. self.global_var = value
  64. self.compiling = True
  65. return None
  66. else:
  67. return self.context_var.set(value)
  68. def reset(self, token):
  69. if self.compiling or token is None:
  70. self.global_var = None
  71. self.compiling = False
  72. else:
  73. self.context_var.reset(token)
  74. # Thread/context-safe global variable
  75. _active_collector = CompileableContextVar("output_collector")
  76. def install_output_capuring_hook(module: nn.Module, key: str, index: int) -> None:
  77. """Install the forward hook needed to capture the output described by `key` and `index` in `module`."""
  78. def output_capturing_hook(module, args, output):
  79. # Get the current thread-local collector
  80. collected_outputs = _active_collector.get()
  81. # If it's None or not a key we want to capture, simply return, the hook is inactive
  82. if collected_outputs is None or key not in collected_outputs.keys():
  83. return
  84. if key == "hidden_states" and len(collected_outputs[key]) == 0:
  85. collected_outputs[key].append(args[0])
  86. if not isinstance(output, tuple):
  87. collected_outputs[key].append(output)
  88. elif output[index] is not None:
  89. collected_outputs[key].append(output[index])
  90. module.register_forward_hook(output_capturing_hook)
  91. def recursively_install_hooks(
  92. parent_module: nn.Module, module_name: str, capture_tasks: list[tuple[str, OutputRecorder]]
  93. ) -> None:
  94. """
  95. Recursively install all output capturing hooks on all submodules of `parent_module`.
  96. Note that we need to use this recursive approach instead of simply iterating over all modules, because we want
  97. to respect the `capture_tasks` of all individual submodels (`PreTrainedModel` instances) in the graph. That is, once
  98. we reach a submodel in the graph, its children should use this submodel's `capture_tasks`, but other parts of the graph
  99. should not.
  100. """
  101. from ..modeling_utils import PreTrainedModel
  102. # First dispatch to children if needed
  103. for name, module in parent_module.named_children():
  104. # Keep dispatching the same `capture_tasks`
  105. if not isinstance(module, PreTrainedModel):
  106. recursively_install_hooks(module, f"{module_name}.{name}", capture_tasks)
  107. # New Submodel: we need to dispatch its own `capture_tasks`
  108. else:
  109. install_all_output_capturing_hooks(module, prefix=f"{module_name}.{name}")
  110. # Potentially install the hook on current `parent_module`
  111. for key, specs in capture_tasks:
  112. # The second check is for multimodals where only backbone layer suffix is available
  113. if (specs.target_class is not None and isinstance(parent_module, specs.target_class)) or (
  114. specs.class_name is not None and module_name.endswith(specs.class_name)
  115. ):
  116. if specs.layer_name is not None and specs.layer_name not in module_name:
  117. continue
  118. install_output_capuring_hook(parent_module, key, specs.index)
  119. def install_all_output_capturing_hooks(model: PreTrainedModel, prefix: str | None = None) -> None:
  120. """
  121. Install the output recording hooks on all the modules in `model`. Tis will take care of correctly dispatching
  122. the `_can_record_outputs` property of each individual submodels in case of composite models.
  123. """
  124. # _can_record_outputs is None by default
  125. capture_flags = _CAN_RECORD_REGISTRY.get(str(model.__class__)) or {} # there is a weak ref for executorch
  126. capture_tasks = []
  127. for key, layer_specs in capture_flags.items():
  128. if not isinstance(layer_specs, list):
  129. layer_specs = [layer_specs]
  130. for specs in layer_specs:
  131. if not isinstance(specs, OutputRecorder):
  132. index = 0 if "hidden_states" in key else 1
  133. class_name = None if not isinstance(specs, str) else specs
  134. target_class = specs if not isinstance(specs, str) else None
  135. specs = OutputRecorder(target_class=target_class, index=index, class_name=class_name)
  136. capture_tasks.append((key, specs))
  137. # Install the hooks
  138. prefix = prefix if prefix is not None else ""
  139. recursively_install_hooks(model, prefix, capture_tasks)
  140. # Mark the model as already hooked
  141. setattr(model, "_output_capturing_hooks_installed", True)
  142. # We need this to make sure we don't have race conditions when installing hooks, resulting in them being installed
  143. # several times
  144. _hook_installation_lock = threading.Lock()
  145. def maybe_install_capturing_hooks(model: PreTrainedModel) -> None:
  146. """
  147. Check if the model already has output capturing hooks installed, and install them if it is not already the
  148. case.
  149. Note that this is thread-safe, in case 2 (or more) threads want to install them concurrently.
  150. """
  151. # First check
  152. if getattr(model, "_output_capturing_hooks_installed", False):
  153. return
  154. with _hook_installation_lock:
  155. # Second check, in case several threads entered this function concurrently and did not return on the
  156. # previous check
  157. if getattr(model, "_output_capturing_hooks_installed", False):
  158. return
  159. # This will install the hooks and mark the model as hooked
  160. install_all_output_capturing_hooks(model)
  161. def capture_outputs(func=None, *, tie_last_hidden_states=True):
  162. """
  163. Decorator to intercept specific layer outputs through hooks. The hooks are installed only once and lazily,
  164. the first time output capture is requested with the `output_xxx` kwargs/config.
  165. The implementation is fully context/thread safe, except when using `torch.compile`, as dynamo is unable to trace
  166. through `ContextVar` methods.
  167. Args:
  168. tie_last_hidden_states (`bool`, *optional*, defaults to `True`):
  169. Whether to overwrite `out.hidden_states[-1]` with the `out.last_hidden_state`.
  170. This is true for all language models and should be toggled off only if
  171. `out.hidden_states[-1]` has to be the hidden state before last layer norm, which
  172. is needed for some vision models (e.g. CLIP, SigLIP)
  173. """
  174. def wrapped_fn(func):
  175. @wraps(func)
  176. def wrapper(self, *args, **kwargs):
  177. # Pop it so that internal modules always return a dict even if False is requested
  178. return_dict = kwargs.pop("return_dict", getattr(self.config, "return_dict", True))
  179. # _can_record_outputs is None by default
  180. capturable_flags = _CAN_RECORD_REGISTRY.get(str(self.__class__)) or {}
  181. recordable_keys = {
  182. f"output_{k}": kwargs.get(f"output_{k}", getattr(self.config, f"output_{k}", False))
  183. for k in capturable_flags
  184. }
  185. # For BC as cross-attentions used to be captured with `output_attentions`
  186. if "cross_attentions" in capturable_flags:
  187. recordable_keys["output_cross_attentions"] = kwargs.get(
  188. "output_attentions", getattr(self.config, "output_attentions", False)
  189. )
  190. # The sam model variants need this annoying exception as well...
  191. if "mask_decoder_attentions" in capturable_flags:
  192. recordable_keys["output_mask_decoder_attentions"] = kwargs.get(
  193. "output_attentions", getattr(self.config, "output_attentions", False)
  194. )
  195. collected_outputs = {k.replace("output_", ""): [] for k, v in recordable_keys.items() if v}
  196. # Make sure hooks are installed if we need to collect outputs
  197. if len(collected_outputs) > 0:
  198. maybe_install_capturing_hooks(self)
  199. # Let's activate the output collector hooks if needed!
  200. output_token = _active_collector.set(collected_outputs)
  201. # Run the forward
  202. try:
  203. outputs = func(self, *args, **kwargs)
  204. # Reset the states
  205. finally:
  206. _active_collector.reset(output_token)
  207. # Inject collected outputs into model output (return everything as tuples for BC)
  208. for key in collected_outputs:
  209. if key == "hidden_states":
  210. if not tie_last_hidden_states:
  211. pass
  212. elif hasattr(outputs, "vision_hidden_states"):
  213. collected_outputs[key] = collected_outputs[key][:-1]
  214. collected_outputs[key].append(outputs.vision_hidden_states)
  215. elif hasattr(outputs, "last_hidden_state"):
  216. collected_outputs[key] = collected_outputs[key][:-1]
  217. collected_outputs[key].append(outputs.last_hidden_state)
  218. outputs[key] = tuple(collected_outputs[key])
  219. elif key == "attentions":
  220. # In this case, the second item are cross attentions
  221. if isinstance(capturable_flags[key], list) and len(capturable_flags[key]) == 2:
  222. outputs[key] = tuple(collected_outputs[key][0::2])
  223. outputs["cross_" + key] = tuple(collected_outputs[key][1::2])
  224. else:
  225. outputs[key] = tuple(collected_outputs[key])
  226. else:
  227. outputs[key] = tuple(collected_outputs[key])
  228. if return_dict is False:
  229. outputs = outputs.to_tuple()
  230. return outputs
  231. return wrapper
  232. if func is not None:
  233. return wrapped_fn(func)
  234. return wrapped_fn