monkey_patching.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357
  1. # Copyright 2026 The HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import re
  15. import sys
  16. import threading
  17. from contextlib import contextmanager
  18. from .utils import is_torch_available, logging
  19. from .utils.output_capturing import OutputRecorder
  20. if is_torch_available():
  21. import torch.nn as nn
  22. logger = logging.get_logger(__name__)
  23. _monkey_patch_mapping_cache: dict[str, type[nn.Module]] = {}
  24. _compiled_patterns_cache: dict[str, re.Pattern] = {}
  25. _monkey_patch_lock = threading.Lock()
  26. def _compile_pattern(pattern: str) -> re.Pattern | None:
  27. """
  28. Compile a regex pattern and cache it. Returns None if pattern is invalid.
  29. Args:
  30. pattern: The regex pattern string to compile
  31. Returns:
  32. Compiled regex pattern or None if invalid
  33. """
  34. if pattern in _compiled_patterns_cache:
  35. return _compiled_patterns_cache[pattern]
  36. try:
  37. compiled = re.compile(pattern)
  38. _compiled_patterns_cache[pattern] = compiled
  39. return compiled
  40. except re.error as e:
  41. logger.warning(f"Invalid regex pattern '{pattern}': {e}. Treating as non-pattern.")
  42. return None
  43. def _find_replacement_class(class_name: str, mapping: dict[str, type[nn.Module]]) -> type[nn.Module] | None:
  44. """
  45. Find replacement class for a given class name, checking exact matches first, then regex patterns.
  46. Args:
  47. class_name: The class name to find a replacement for
  48. mapping: Dictionary of patterns/names to replacement classes
  49. Returns:
  50. The replacement class if found, None otherwise
  51. """
  52. # First check for exact match (highest priority)
  53. if class_name in mapping:
  54. return mapping[class_name]
  55. # Then check regex patterns
  56. for pattern, replacement_class in mapping.items():
  57. # Skip if already matched as exact
  58. if pattern == class_name:
  59. continue
  60. # Try to compile and match as regex
  61. compiled_pattern = _compile_pattern(pattern)
  62. if compiled_pattern is not None and compiled_pattern.search(class_name):
  63. return replacement_class
  64. return None
  65. def register_patch_mapping(mapping: dict[str, type[nn.Module]], overwrite: bool = False) -> None:
  66. """
  67. Register patch mappings to enable automatic patching during model creation using `from_pretrained`,
  68. `from_config` or within the `apply_patches` context manager.
  69. Use this to register class replacements that will be automatically applied when loading any model.
  70. This is useful for quantization library compatibility, structural optimizations, and architectural
  71. experimentation. The mapping is global, can grow with multiple calls, and can be cleared entirely.
  72. Args:
  73. mapping (`Dict[str, type[nn.Module]]`):
  74. Mapping from original class names (or regex patterns) to replacement classes. Supports:
  75. - Exact class names: `"Qwen2MoeExperts"` → `CustomExperts`
  76. - Regex patterns: `".*Attention"` matches `LlamaAttention`, `MistralAttention`, etc.,
  77. or `"^Llama\\d+Attention$"` matches `Llama2Attention`, `Llama3Attention`, etc.
  78. Exact matches take precedence over patterns. Patterns are matched using `re.search()`,
  79. so they can match anywhere in the class name unless you use anchors (`^` for start, `$` for end).
  80. overwrite (`bool`, *optional*, defaults to `False`):
  81. Whether to overwrite existing mappings for class names that are already registered.
  82. Example:
  83. ```python
  84. from transformers import AutoModelForCausalLM
  85. from transformers.monkey_patching import register_patch_mapping
  86. # Define custom expert implementation
  87. class SequentialExperts(nn.Module):
  88. ...
  89. # Register exact class name
  90. register_patch_mapping(
  91. mapping={"Qwen2MoeExperts": SequentialExperts}
  92. )
  93. # Register with regex pattern to match multiple classes
  94. register_patch_mapping(
  95. mapping={".*Attention": CustomAttention} # Matches LlamaAttention, MistralAttention, etc.
  96. )
  97. # Match specific model versions
  98. register_patch_mapping(
  99. mapping={"^Llama\\d+Attention$": CustomLlamaAttention} # Matches Llama2Attention, Llama3Attention
  100. )
  101. # The patch will be automatically applied during loading
  102. model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")
  103. ```
  104. Note:
  105. For weight conversions, use [`~transformers.register_checkpoint_conversion_mapping`] instead.
  106. """
  107. global _monkey_patch_mapping_cache
  108. with _monkey_patch_lock:
  109. for class_name, replacement_class in mapping.items():
  110. # Validate that replacement_class is actually a class and is a subclass of nn.Module
  111. if not isinstance(replacement_class, type):
  112. raise TypeError(
  113. f"Replacement for '{class_name}' must be a class, got {type(replacement_class).__name__}"
  114. )
  115. if not issubclass(replacement_class, nn.Module):
  116. raise TypeError(
  117. f"Replacement class for '{class_name}' must be a subclass of nn.Module, "
  118. f"got {replacement_class.__name__} which inherits from {[c.__name__ for c in replacement_class.__mro__[1:]]}"
  119. )
  120. if class_name in _monkey_patch_mapping_cache and not overwrite:
  121. raise ValueError(
  122. f"Class '{class_name}' already has a patch mapping registered. Use overwrite=True to replace it."
  123. )
  124. _monkey_patch_mapping_cache[class_name] = replacement_class
  125. def unregister_patch_mapping(keys: list[str]) -> None:
  126. """
  127. Unregister patch mappings to disable automatic patching.
  128. This removes specified mappings from the global registry, preventing them from being applied
  129. during model loading. You must provide the exact same name or pattern that was used during registration.
  130. Args:
  131. keys (`List[str]`):
  132. List of mapping keys (class names or regex patterns) to remove from the patch mapping
  133. (e.g., `["Qwen2MoeExperts"]` or `[".*Attention"]`).
  134. Example:
  135. ```python
  136. from transformers import AutoModelForCausalLM
  137. from transformers.monkey_patching import register_patch_mapping, unregister_patch_mapping
  138. # Register a patch
  139. register_patch_mapping(
  140. mapping={"Qwen2MoeExperts": CustomExperts}
  141. )
  142. # Unregister the patch
  143. unregister_patch_mapping(["Qwen2MoeExperts"])
  144. # The patch will no longer be applied during loading
  145. model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen1.5-MoE-A2.7B")
  146. ```
  147. """
  148. global _monkey_patch_mapping_cache
  149. with _monkey_patch_lock:
  150. for key in keys:
  151. if key not in _monkey_patch_mapping_cache:
  152. raise ValueError(
  153. f"Class or pattern '{key}' not found in monkey patch mapping cache. "
  154. f"Cannot unregister a class that is not registered."
  155. )
  156. del _monkey_patch_mapping_cache[key]
  157. def get_patch_mapping() -> dict[str, type[nn.Module]]:
  158. """
  159. Get all registered patch mappings.
  160. Returns:
  161. `Dict[str, type[nn.Module]]`: Dictionary mapping class names or patterns to replacement classes.
  162. """
  163. with _monkey_patch_lock:
  164. return _monkey_patch_mapping_cache.copy()
  165. def clear_patch_mapping() -> None:
  166. """
  167. Clear all registered patch mappings.
  168. This removes all registered mappings from the global registry.
  169. Example:
  170. ```python
  171. from transformers.monkey_patching import register_patch_mapping, clear_patch_mapping
  172. # Register some patches
  173. register_patch_mapping(
  174. mapping={"Qwen2MoeExperts": CustomExperts}
  175. )
  176. # Clear all patches
  177. clear_patch_mapping()
  178. ```
  179. """
  180. global _monkey_patch_mapping_cache
  181. with _monkey_patch_lock:
  182. _monkey_patch_mapping_cache.clear()
  183. @contextmanager
  184. def apply_patches():
  185. """
  186. Context manager to apply registered monkey patches within a block of code.
  187. This temporarily replaces original classes with their registered replacements during the execution of the block, and restores the original classes afterward.
  188. Example:
  189. ```python
  190. from transformers import Qwen2MoeModel, Qwen2MoeConfig
  191. from transformers.monkey_patching import register_patch_mapping, apply_patches
  192. # Register a patch
  193. register_patch_mapping(
  194. mapping={"Qwen2MoeExperts": CustomExperts}
  195. )
  196. # Apply patches within the context
  197. with apply_patches():
  198. # The model will use CustomExperts instead of Qwen2MoeExperts
  199. model = Qwen2MoeModel(Qwen2MoeConfig())
  200. # Outside the context, original classes are restored
  201. # The model will use Qwen2MoeExperts again
  202. model = Qwen2MoeModel(Qwen2MoeConfig())
  203. ```
  204. """
  205. mapping = get_patch_mapping()
  206. if not mapping:
  207. yield
  208. return
  209. original_classes = {}
  210. # Create list to avoid dict changed during iteration
  211. for module in list(sys.modules.values()):
  212. if module is None or not hasattr(module, "__name__"):
  213. continue
  214. if not module.__name__.startswith("transformers"):
  215. continue
  216. # Iterate through all attributes in transformers modules
  217. for attr_name in dir(module):
  218. # Check if this attribute name matches any pattern before accessing it
  219. replacement_class = _find_replacement_class(attr_name, mapping)
  220. if replacement_class is None:
  221. continue
  222. try:
  223. attr = getattr(module, attr_name)
  224. # Check if it's a class
  225. if not isinstance(attr, type):
  226. continue
  227. original_classes[(module.__name__, attr_name)] = attr
  228. setattr(module, attr_name, replacement_class)
  229. except (AttributeError, TypeError, ImportError):
  230. # Skip attributes that can't be accessed or modules that can't be imported
  231. continue
  232. yield
  233. for (module_name, class_name), original_class in original_classes.items():
  234. module = sys.modules[module_name]
  235. setattr(module, class_name, original_class)
  236. # _can_record_outputs is a class attribute so patching and unpatching it in the class won't work
  237. # since the model instance will still reference the original class's _can_record_outputs.
  238. def patch_output_recorders(model: nn.Module) -> None:
  239. """
  240. Patch the model instance's output recorders to use the registered replacement classes.
  241. This function updates output recorders in a model's submodules to use monkey-patched replacement
  242. classes. Output recorders are used by the transformers library to track intermediate outputs during
  243. forward passes (via the `_can_record_outputs` attribute). When classes are monkey-patched, these
  244. recorders need to be updated to reference the new classes.
  245. This is automatically called during model initialization when loading with `from_pretrained` or
  246. `from_config`. You typically don't need to call this manually unless you're constructing models
  247. in custom ways.
  248. Note:
  249. The `_can_record_outputs` attribute is a class-level attribute that maps output names to either:
  250. - `OutputRecorder` instances that have a `target_class` attribute
  251. - Class types directly
  252. This function patches both cases to use the replacement classes from the monkey patch registry.
  253. Args:
  254. model (`nn.Module`):
  255. The model instance whose output recorders should be patched. All submodules will be
  256. traversed to find and patch their `_can_record_outputs` attributes.
  257. Example:
  258. ```python
  259. from transformers import AutoModelForCausalLM
  260. from transformers.monkey_patching import register_patch_mapping, patch_output_recorders
  261. # Register a patch
  262. register_patch_mapping(mapping={"Qwen2MoeExperts": CustomExperts})
  263. # If you construct a model manually (without from_pretrained), patch recorders
  264. model = Qwen2MoeModel(config)
  265. patch_output_recorders(model) # Updates output recorders to use CustomExperts
  266. ```
  267. """
  268. mapping = get_patch_mapping()
  269. if not mapping:
  270. return
  271. for submodule in model.modules():
  272. if hasattr(submodule, "_can_record_outputs") and submodule._can_record_outputs is not None:
  273. for output, recorder in submodule._can_record_outputs.items():
  274. if isinstance(recorder, OutputRecorder):
  275. # Check if target class matches any registered pattern or exact name
  276. replacement_class = _find_replacement_class(recorder.target_class.__name__, mapping)
  277. if replacement_class is not None:
  278. recorder.target_class = replacement_class
  279. elif isinstance(recorder, type):
  280. # Check if class type matches any registered pattern or exact name
  281. replacement_class = _find_replacement_class(recorder.__name__, mapping)
  282. if replacement_class is not None:
  283. submodule._can_record_outputs[output] = replacement_class