module_tracker.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. # mypy: allow-untyped-defs
  2. import logging
  3. import weakref
  4. from typing import TYPE_CHECKING
  5. import torch
  6. from torch.autograd.graph import register_multi_grad_hook
  7. from torch.nn.modules.module import (
  8. register_module_forward_hook,
  9. register_module_forward_pre_hook,
  10. )
  11. from torch.utils._pytree import tree_flatten
  12. if TYPE_CHECKING:
  13. from torch.utils.hooks import RemovableHandle
  14. logger = logging.getLogger(__name__)
  15. __all__ = ["ModuleTracker"]
  16. class ModuleTracker:
  17. """
  18. ``ModuleTracker`` is a context manager that tracks the nn.Module hierarchy during execution
  19. so that other system can query which Module is currently being executed (or its backward is being
  20. executed).
  21. You can access the ``parents`` attribute on this context manager to get the set of all the
  22. Modules currently being executed via their fqn (fully qualified name, also used as the key within
  23. the state_dict).
  24. You can access the ``is_bw`` attribute to know if you are currently running in backward or not.
  25. Note that ``parents`` is never empty and always contains the "Global" key. The ``is_bw`` flag
  26. will remain ``True`` after the forward until another Module is executed. If you need it to be
  27. more accurate, please submit an issue requesting this. Adding a map from fqn to the module instance
  28. is possible but not done yet, please submit an issue requesting this if you need it.
  29. Example usage
  30. .. code-block:: python
  31. mod = torch.nn.Linear(2, 2)
  32. with ModuleTracker() as tracker:
  33. # Access anything during the forward pass
  34. def my_linear(m1, m2, bias):
  35. print(f"Current modules: {tracker.parents}")
  36. return torch.mm(m1, m2.t()) + bias
  37. torch.nn.functional.linear = my_linear
  38. mod(torch.rand(2, 2))
  39. """
  40. parents: set[str]
  41. """
  42. A Set containing the fqn for each module currently running their forward
  43. """
  44. def __init__(self) -> None:
  45. self.parents = {"Global"}
  46. self._known_modules: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
  47. self._seen_modules: weakref.WeakSet = weakref.WeakSet()
  48. self._has_callback = False
  49. self._hooks: list[RemovableHandle] = []
  50. def _maybe_set_engine_callback(self) -> None:
  51. # This assumes no concurrent calls to backward
  52. if self._has_callback:
  53. return
  54. def callback() -> None:
  55. self.parents = {"Global"}
  56. self._has_callback = False
  57. torch.autograd.Variable._execution_engine.queue_callback(callback)
  58. self._has_callback = True
  59. @property
  60. def is_bw(self):
  61. """
  62. A boolean marking if this is currently running during the backward pass or not
  63. """
  64. return torch._C._current_graph_task_id() != -1
  65. def _get_mod_name(self, mod):
  66. if mod not in self._known_modules:
  67. self._known_modules[mod] = type(mod).__name__
  68. mod_name = self._known_modules[mod]
  69. if mod not in self._seen_modules:
  70. for name, submod in mod.named_children():
  71. self._known_modules[submod] = f"{mod_name}.{name}"
  72. self._get_mod_name(submod)
  73. self._seen_modules.add(mod)
  74. return mod_name
  75. def _get_append_fn(self, name, is_bw):
  76. def fn(*args) -> None:
  77. if is_bw:
  78. self._maybe_set_engine_callback()
  79. if name in self.parents:
  80. logger.info(
  81. "The module hierarchy tracking seems to be broken as this Module was already entered. %s during %s",
  82. name,
  83. "backward" if is_bw else "forward",
  84. )
  85. self.parents.add(name)
  86. return fn
  87. def _get_pop_fn(self, name, is_bw):
  88. def fn(*args) -> None:
  89. if name in self.parents:
  90. self.parents.remove(name)
  91. else:
  92. logger.info(
  93. "The Module hierarchy tracking is confused as we're exiting a Module that was never entered. %s during %s",
  94. name,
  95. "backward" if is_bw else "forward",
  96. )
  97. return fn
  98. def _fw_pre_hook(self, mod, input) -> None:
  99. name = self._get_mod_name(mod)
  100. self._get_append_fn(name, False)()
  101. args, _ = tree_flatten(input)
  102. tensors = [a for a in args if isinstance(a, torch.Tensor) and a.requires_grad]
  103. if tensors:
  104. self._hooks.append(
  105. register_multi_grad_hook(tensors, self._get_pop_fn(name, True))
  106. )
  107. def _fw_post_hook(self, mod, input, output) -> None:
  108. name = self._get_mod_name(mod)
  109. self._get_pop_fn(name, False)()
  110. args, _ = tree_flatten(output)
  111. tensors = [a for a in args if isinstance(a, torch.Tensor) and a.requires_grad]
  112. if tensors:
  113. self._hooks.append(
  114. register_multi_grad_hook(tensors, self._get_append_fn(name, True))
  115. )
  116. def __enter__(self):
  117. self._fw_pre_handle = register_module_forward_pre_hook(self._fw_pre_hook)
  118. self._fw_post_handle = register_module_forward_hook(self._fw_post_hook)
  119. return self
  120. def __exit__(self, *args):
  121. self._fw_pre_handle.remove()
  122. self._fw_post_handle.remove()
  123. for hook in self._hooks:
  124. hook.remove()
  125. self._hooks.clear()