mutation_guard.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. """Mutation tracking and dynamic module detection system for Dynamo.
  2. This module provides mechanisms to track and respond to mutations in PyTorch modules
  3. and detect dynamically created or modified modules.
  4. Key components:
  5. - MutationTracker: Tracks mutations to objects and invalidates associated cached code
  6. - GenerationTracker: Tracks module creation timing to identify dynamic instances
  7. - Patching system for nn.Module to detect mutations and dynamic creation
  8. The system ensures that Dynamo's optimizations remain valid by detecting and responding
  9. to runtime changes in module state and structure.
  10. """
  11. import functools
  12. import weakref
  13. from collections.abc import MutableMapping
  14. from typing import Any
  15. import torch.nn
  16. from torch.nn import Module
  17. from . import config
  18. from .utils import ExactWeakKeyDictionary, nn_module_has_global_hooks
# Reference to torch.nn.Module.__init__ as it was when this module was
# imported, i.e. before this module's install_generation_tagging_init has had
# a chance to replace Module.__init__ with its generation-tagging wrapper.
# NOTE(review): not used in this file; presumably consumed elsewhere to reach
# the untagged constructor — confirm against callers.
unpatched_nn_module_init = torch.nn.Module.__init__
  20. class MutationTracker:
  21. db: ExactWeakKeyDictionary = ExactWeakKeyDictionary()
  22. def __init__(self) -> None:
  23. self.mutation_count: int = 0
  24. self.watchers: list[weakref.ReferenceType[Any]] = []
  25. def on_mutation(self, name: str) -> None:
  26. self.mutation_count += 1
  27. tmp = self.watchers
  28. self.watchers = []
  29. for ref in tmp:
  30. guarded = ref()
  31. if guarded is not None:
  32. guarded.invalidate(ref)
  33. def track(self, guarded_code: Any) -> None:
  34. self.watchers.append(weakref.ref(guarded_code))
  35. def watch(obj: Any, guarded_code: Any) -> None:
  36. """invalidate guarded_code when obj is mutated"""
  37. ensure_patched(type(obj))
  38. if obj not in MutationTracker.db:
  39. MutationTracker.db[obj] = MutationTracker()
  40. tracker = MutationTracker.db[obj]
  41. tracker.track(guarded_code)
  42. def ensure_patched(cls: Any) -> None:
  43. if getattr(cls, "___needs_mutation_patch", True):
  44. cls.___needs_mutation_patch = False
  45. original_setattr = cls.__setattr__
  46. @functools.wraps(original_setattr)
  47. def custom_setattr(self: Any, key: str, value: Any) -> None:
  48. try:
  49. MutationTracker.db[self].on_mutation(key)
  50. except KeyError:
  51. pass
  52. return original_setattr(self, key, value)
  53. cls.__setattr__ = custom_setattr
  54. class GenerationTracker:
  55. generation: int = 0
  56. dynamic_classes: ExactWeakKeyDictionary = ExactWeakKeyDictionary()
  57. generation_values: ExactWeakKeyDictionary = ExactWeakKeyDictionary()
  58. @classmethod
  59. def tag(cls, obj: Any) -> None:
  60. cls.generation_values[obj] = cls.generation
  61. @staticmethod
  62. def mark_class_dynamic(cls: type[torch.nn.Module]) -> None:
  63. assert issubclass(cls, torch.nn.Module)
  64. GenerationTracker.dynamic_classes[cls] = True
  65. @classmethod
  66. def get_generation_value(cls, obj: Any) -> int:
  67. if obj not in cls.generation_values:
  68. return -1
  69. return cls.generation_values[obj]
  70. @classmethod
  71. def check(cls, obj: Any) -> bool:
  72. return (
  73. obj in cls.generation_values
  74. and cls.generation_values[obj] == cls.generation
  75. )
  76. @classmethod
  77. def clear(cls) -> None:
  78. cls.generation = 0
  79. cls.dynamic_classes = ExactWeakKeyDictionary()
  80. cls.generation_values = ExactWeakKeyDictionary()
  81. def is_dynamic_nn_module(obj: Any, is_export: bool) -> bool:
  82. """Check for nn.Modules() created dynamically or mutated"""
  83. if isinstance(obj, torch.nn.Module) and (
  84. "forward" in obj.__dict__ or isinstance(obj, (dict, MutableMapping))
  85. ):
  86. # A monkey patched `.forward` indicates something wacky is going on
  87. # Similarly a nn module also subclassed as a dict is unusual.
  88. return True
  89. if hasattr(obj, "torchdynamo_force_dynamic"):
  90. return obj.torchdynamo_force_dynamic
  91. if (
  92. isinstance(obj, torch.nn.Module)
  93. and config.inline_inbuilt_nn_modules
  94. and (not is_export or config.install_free_tensors)
  95. ):
  96. return True
  97. if isinstance(obj, torch.nn.Module) and nn_module_has_global_hooks():
  98. return True
  99. dyn = GenerationTracker.dynamic_classes.get(type(obj)) or GenerationTracker.check(
  100. obj
  101. )
  102. return dyn
  103. def install_generation_tagging_init() -> None:
  104. """
  105. Monkey patch torch.nn.Module.__init__ and torch.nn.Module.__setstate__
  106. so we can detect nn.Module instances created dynamically inside forward methods.
  107. """
  108. if getattr(Module, "___needs_generation_tag_patch", True):
  109. init = Module.__init__
  110. def patched_init(self: Module, *args: Any, **kwargs: Any) -> None:
  111. init(self, *args, **kwargs)
  112. GenerationTracker.tag(self)
  113. Module.__init__ = patched_init # type: ignore[method-assign]
  114. setstate = Module.__setstate__
  115. def patched_setstate(self: Module, state: Any) -> None:
  116. setstate(self, state)
  117. GenerationTracker.tag(self)
  118. Module.__setstate__ = patched_setstate # type: ignore[method-assign]
  119. Module.___needs_generation_tag_patch = False # type: ignore[attr-defined]
  120. GenerationTracker.generation += 1