tracer.py 1.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748
  1. from collections.abc import Callable
  2. import torch
  3. from torch.ao.nn.intrinsic import _FusedModule
  4. from torch.fx._symbolic_trace import Tracer
  5. from torch.fx.proxy import Scope
  6. __all__ = [
  7. "QuantizationTracer",
  8. ]
  9. class ScopeContextManager(torch.fx.proxy.ScopeContextManager):
  10. def __init__(
  11. self, scope: Scope, current_module: torch.nn.Module, current_module_path: str
  12. ):
  13. super().__init__(scope, Scope(current_module_path, type(current_module)))
  14. class QuantizationTracer(Tracer):
  15. def __init__(
  16. self, skipped_module_names: list[str], skipped_module_classes: list[Callable]
  17. ):
  18. super().__init__()
  19. self.skipped_module_names = skipped_module_names
  20. self.skipped_module_classes = skipped_module_classes
  21. # NB: initialized the module_type of top level module to None
  22. # we are assuming people won't configure the model with the type of top level
  23. # module here, since people can use "" for global config
  24. # We can change this if there is a use case that configures
  25. # qconfig using top level module type
  26. self.scope = Scope("", None)
  27. self.record_stack_traces = True
  28. def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool:
  29. return (
  30. (
  31. (
  32. m.__module__.startswith("torch.nn")
  33. or m.__module__.startswith("torch.ao.nn")
  34. )
  35. and not isinstance(m, torch.nn.Sequential)
  36. )
  37. or module_qualified_name in self.skipped_module_names
  38. or type(m) in self.skipped_module_classes
  39. or isinstance(m, _FusedModule)
  40. )