| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285 |
- # Copyright 2026 The HuggingFace Team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """
- Contains the logic for automatic additional output capture with our forward decorators.
- This mostly describes the hooks used and the logic that makes the capture thread/context safe.
- """
- from __future__ import annotations
- import threading
- from contextvars import ContextVar
- from dataclasses import dataclass
- from functools import wraps
- from typing import TYPE_CHECKING
- from .import_utils import is_torchdynamo_compiling, requires
- if TYPE_CHECKING:
- from torch import nn
- from ..modeling_utils import PreTrainedModel
# Registry consulted by the hook-installation and output-capture code of this module. It is looked up with
# `str(model_class)` as key, and a missing (or falsy) entry means "nothing to record". Presumably populated elsewhere
# with each class's `_can_record_outputs` specification -- TODO confirm against the registering code.
_CAN_RECORD_REGISTRY = dict()
@dataclass
@requires(backends=("torch",))
class OutputRecorder:
    """
    Configuration for recording outputs from a model via forward hooks.

    A submodule receives the hook when it is an instance of `target_class`, or when its dotted qualified name ends
    with `class_name`. If `layer_name` is given, matching submodules must additionally contain it in their qualified
    name.

    Attributes:
        target_class (`type[nn.Module]`, *optional*):
            The module class to which the hook will be attached. May be `None` when matching only by `class_name`.
        index (`int`, defaults to `0`):
            When the hooked module returns a tuple, the position of the element to record (`None` elements are
            skipped). Non-tuple outputs are recorded as a whole.
        layer_name (`str`, *optional*):
            Substring that the submodule's qualified name must contain, e.g. `"transformer.layer.3.attn"`.
        class_name (`str`, *optional*):
            Suffix matched against the submodule's qualified name, e.g. for multimodal models where only the
            backbone layer suffix is available.
    """

    # `None` is allowed (and used internally) when the recorder matches by `class_name` only
    target_class: type[nn.Module] | None = None
    index: int = 0
    layer_name: str | None = None
    class_name: str | None = None
class CompileableContextVar:
    """
    Convenience wrapper around a ContextVar for usage with `torch.compile`.

    This behaves exactly as a `ContextVar`, except when compilation is triggered in which case it behaves as a simple
    global variable. This is useful as `torch.compile` cannot trace the `get` method of `ContextVar`. This however means
    that the access to the underlying variable is not thread-safe when compilation is triggered.

    Under compilation, `set` returns a snapshot of the previous global state as its token, so that nested `set`/`reset`
    pairs (e.g. a decorated submodel called from a decorated parent model) restore the outer value instead of
    erasing it.
    """

    def __init__(self, name):
        self.context_var = ContextVar(name, default=None)
        # State used in place of `context_var` once compilation has been detected in `set`
        self.global_var = None
        self.compiling = False

    def get(self):
        """Return the current value, or `None` if nothing is set."""
        # Set was called before and compilation was already detected
        if self.compiling:
            return self.global_var
        return self.context_var.get()

    def set(self, value):
        """
        Set the current value and return a token to pass to `reset` in order to undo this call.

        While dynamo is compiling, the value is stored in a plain attribute and the returned token is the
        `(compiling, global_var)` tuple that was current before this call. Otherwise, this is `ContextVar.set`.
        """
        if is_torchdynamo_compiling():
            previous_state = (self.compiling, self.global_var)
            self.global_var = value
            self.compiling = True
            return previous_state
        return self.context_var.set(value)

    def reset(self, token):
        """Undo the `set` call that returned `token`."""
        if isinstance(token, tuple):
            # Token created while compiling: restore the snapshotted global state (keeps outer values alive when nested)
            self.compiling, self.global_var = token
        elif self.compiling or token is None:
            # No previous state is known: simply clear the global variable
            self.global_var = None
            self.compiling = False
        else:
            self.context_var.reset(token)


# Context-local (except under `torch.compile`, see `CompileableContextVar`) handle on the outputs being captured:
# `None` when no capture is running, otherwise a dict mapping each requested key to the list the hooks append to
_active_collector = CompileableContextVar("output_collector")
def install_output_capuring_hook(module: nn.Module, key: str, index: int) -> None:
    """
    Register on `module` a forward hook that records its output under `key` of the active collector.

    The hook is a no-op unless a collector is active and requests `key`. For tuple outputs only `output[index]` is
    recorded, and only when it is not `None`; any other output is recorded as is. For `hidden_states`, the first
    positional input of the first hooked call is recorded before its output.
    """

    def _record_output(module, args, output):
        collector = _active_collector.get()
        # Inactive hook: no capture in progress, or this key was not requested
        if collector is None or key not in collector:
            return
        records = collector[key]
        if key == "hidden_states" and not records:
            records.append(args[0])
        if isinstance(output, tuple):
            if output[index] is not None:
                records.append(output[index])
        else:
            records.append(output)

    module.register_forward_hook(_record_output)
def recursively_install_hooks(
    parent_module: nn.Module, module_name: str, capture_tasks: list[tuple[str, OutputRecorder]]
) -> None:
    """
    Walk the module tree below `parent_module` and install the output capturing hooks described by `capture_tasks`.

    The walk is recursive rather than a flat loop over all modules so that every submodel (`PreTrainedModel`
    instance) of the graph uses its own capture specification: below a submodel, its `capture_tasks` apply, while
    the rest of the graph keeps using the caller's.

    Args:
        parent_module: module whose subtree is processed; children are handled before the module itself.
        module_name: dotted qualified name of `parent_module`, used for `class_name`/`layer_name` matching.
        capture_tasks: `(key, OutputRecorder)` pairs to apply within this subtree.
    """
    from ..modeling_utils import PreTrainedModel

    def _matches(specs: OutputRecorder) -> bool:
        if specs.target_class is not None and isinstance(parent_module, specs.target_class):
            return True
        # Name-suffix match, for multimodals where only the backbone layer suffix is available
        return specs.class_name is not None and module_name.endswith(specs.class_name)

    # Children first: a nested submodel dispatches its own `capture_tasks`, anything else inherits ours
    for child_name, child in parent_module.named_children():
        child_path = f"{module_name}.{child_name}"
        if isinstance(child, PreTrainedModel):
            install_all_output_capturing_hooks(child, prefix=child_path)
        else:
            recursively_install_hooks(child, child_path, capture_tasks)

    # Then, possibly, the current module itself
    for key, specs in capture_tasks:
        if not _matches(specs):
            continue
        if specs.layer_name is not None and specs.layer_name not in module_name:
            continue
        install_output_capuring_hook(parent_module, key, specs.index)
def install_all_output_capturing_hooks(model: PreTrainedModel, prefix: str | None = None) -> None:
    """
    Install the output recording hooks on all the modules in `model`.

    The capture specification registered for `model`'s class is normalized into `(key, OutputRecorder)` pairs and
    dispatched through the module tree. Nested submodels use their own specification (composite models). Finally,
    `model` is flagged with `_output_capturing_hooks_installed = True`.

    Args:
        model: the model to hook.
        prefix: qualified-name prefix for `model`'s submodules during matching; `None` means `""`.
    """

    def _as_recorder(key: str, spec) -> OutputRecorder:
        # A bare spec is either a module class or a (suffix of a) module name; hidden states default to output
        # index 0, every other key to index 1
        if isinstance(spec, OutputRecorder):
            return spec
        default_index = 0 if "hidden_states" in key else 1
        if isinstance(spec, str):
            return OutputRecorder(target_class=None, index=default_index, class_name=spec)
        return OutputRecorder(target_class=spec, index=default_index, class_name=None)

    # _can_record_outputs is None by default; there is a weak ref for executorch
    capture_flags = _CAN_RECORD_REGISTRY.get(str(model.__class__)) or {}
    capture_tasks = [
        (key, _as_recorder(key, spec))
        for key, layer_specs in capture_flags.items()
        for spec in (layer_specs if isinstance(layer_specs, list) else [layer_specs])
    ]

    recursively_install_hooks(model, "" if prefix is None else prefix, capture_tasks)
    # Flag read by `maybe_install_capturing_hooks` so hooks are never installed twice
    model._output_capturing_hooks_installed = True
# Serializes hook installation: without it, two threads could both see an un-hooked model and install every hook
# twice
_hook_installation_lock = threading.Lock()


def maybe_install_capturing_hooks(model: PreTrainedModel) -> None:
    """
    Lazily install the output capturing hooks on `model`, unless that was already done.

    Safe when several threads call it concurrently: the installed flag is checked once without the lock, then
    re-checked while holding `_hook_installation_lock`, so only one thread performs the installation.
    """

    def hooks_installed() -> bool:
        return getattr(model, "_output_capturing_hooks_installed", False)

    # Fast path: no locking at all once the model is hooked
    if hooks_installed():
        return
    with _hook_installation_lock:
        # Another thread may have completed the installation while we were waiting for the lock
        if not hooks_installed():
            # Installs the hooks and sets `_output_capturing_hooks_installed` on the model
            install_all_output_capturing_hooks(model)
def _resolve_output_flag(kwargs: dict, config, name: str, default=False):
    """
    Resolve a forward flag such as `output_attentions`.

    An explicit, non-`None` value in `kwargs` wins. Otherwise the attribute of the same name on `config` is used,
    falling back to `default` if the config does not define it.
    """
    value = kwargs.get(name)
    if value is None:
        value = getattr(config, name, default)
    return value


def capture_outputs(func=None, *, tie_last_hidden_states=True):
    """
    Decorator to intercept specific layer outputs through hooks. The hooks are installed only once and lazily,
    the first time output capture is requested with the `output_xxx` kwargs/config.

    The implementation is fully context/thread safe, except when using `torch.compile`, as dynamo is unable to trace
    through `ContextVar` methods.

    Passing `return_dict=None` or `output_xxx=None` explicitly is equivalent to not passing them: the value is then
    read from `self.config`. Usable both as `@capture_outputs` and as `@capture_outputs(tie_last_hidden_states=...)`.

    Args:
        tie_last_hidden_states (`bool`, *optional*, defaults to `True`):
            Whether to overwrite `out.hidden_states[-1]` with the `out.last_hidden_state`.
            This is true for all language models and should be toggled off only if
            `out.hidden_states[-1]` has to be the hidden state before last layer norm, which
            is needed for some vision models (e.g. CLIP, SigLIP)
    """

    def wrapped_fn(func):
        @wraps(func)
        def wrapper(self, *args, **kwargs):
            # Pop it so that internal modules always return a dict even if False is requested.
            # An explicit `None` means "use the config default", exactly like a missing kwarg.
            return_dict = kwargs.pop("return_dict", None)
            if return_dict is None:
                return_dict = getattr(self.config, "return_dict", True)

            # _can_record_outputs is None by default
            capturable_flags = _CAN_RECORD_REGISTRY.get(str(self.__class__)) or {}
            recordable_keys = {
                f"output_{k}": _resolve_output_flag(kwargs, self.config, f"output_{k}") for k in capturable_flags
            }
            # For BC as cross-attentions used to be captured with `output_attentions`
            if "cross_attentions" in capturable_flags:
                recordable_keys["output_cross_attentions"] = _resolve_output_flag(
                    kwargs, self.config, "output_attentions"
                )
            # The sam model variants need this annoying exception as well...
            if "mask_decoder_attentions" in capturable_flags:
                recordable_keys["output_mask_decoder_attentions"] = _resolve_output_flag(
                    kwargs, self.config, "output_attentions"
                )
            # Strip only the leading "output_" (a plain `replace` would also mangle keys containing "output_")
            collected_outputs = {k[len("output_") :]: [] for k, v in recordable_keys.items() if v}

            # Make sure hooks are installed if we need to collect outputs
            if collected_outputs:
                maybe_install_capturing_hooks(self)

            # Always activate this model's collector, even an empty one: the hooks ignore keys absent from it
            output_token = _active_collector.set(collected_outputs)
            try:
                outputs = func(self, *args, **kwargs)
            finally:
                # Reset the states
                _active_collector.reset(output_token)

            # Inject collected outputs into model output (return everything as tuples for BC)
            for key, values in collected_outputs.items():
                if key == "hidden_states" and tie_last_hidden_states:
                    # Replace the last captured hidden state by the final one exposed on the model output
                    if hasattr(outputs, "vision_hidden_states"):
                        values = values[:-1] + [outputs.vision_hidden_states]
                    elif hasattr(outputs, "last_hidden_state"):
                        values = values[:-1] + [outputs.last_hidden_state]
                    outputs[key] = tuple(values)
                elif (
                    key == "attentions"
                    and isinstance(capturable_flags[key], list)
                    and len(capturable_flags[key]) == 2
                ):
                    # In this case, the second recorder captures cross attentions, interleaved with self attentions
                    outputs[key] = tuple(values[0::2])
                    outputs["cross_" + key] = tuple(values[1::2])
                else:
                    outputs[key] = tuple(values)

            if return_dict is False:
                outputs = outputs.to_tuple()
            return outputs

        return wrapper

    if func is not None:
        return wrapped_fn(func)
    return wrapped_fn
|