# _fx.py
# Helpers for FX-based feature extraction: registries of no-trace modules / functions
# and thin wrappers around torchvision.models.feature_extraction.
from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union

import torch
from torch import nn

  4. try:
  5. # NOTE we wrap torchvision fns to use timm leaf / no trace definitions
  6. from torchvision.models.feature_extraction import create_feature_extractor as _create_feature_extractor
  7. from torchvision.models.feature_extraction import get_graph_node_names as _get_graph_node_names
  8. has_fx_feature_extraction = True
  9. except ImportError:
  10. has_fx_feature_extraction = False
  11. __all__ = [
  12. 'register_notrace_module',
  13. 'is_notrace_module',
  14. 'get_notrace_modules',
  15. 'register_notrace_function',
  16. 'is_notrace_function',
  17. 'get_notrace_functions',
  18. 'create_feature_extractor',
  19. 'get_graph_node_names',
  20. ]
  21. # modules to treat as leafs when tracing
  22. _leaf_modules = set()
  23. def register_notrace_module(module: Type[nn.Module]):
  24. """
  25. Any module not under timm.models.layers should get this decorator if we don't want to trace through it.
  26. """
  27. _leaf_modules.add(module)
  28. return module
  29. def is_notrace_module(module: Type[nn.Module]):
  30. return module in _leaf_modules
  31. def get_notrace_modules():
  32. return list(_leaf_modules)
  33. # Functions we want to autowrap (treat them as leaves)
  34. _autowrap_functions = set()
  35. def register_notrace_function(name_or_fn):
  36. _autowrap_functions.add(name_or_fn)
  37. return name_or_fn
  38. def is_notrace_function(func: Callable):
  39. return func in _autowrap_functions
  40. def get_notrace_functions():
  41. return list(_autowrap_functions)
  42. def get_graph_node_names(model: nn.Module) -> Tuple[List[str], List[str]]:
  43. return _get_graph_node_names(
  44. model,
  45. tracer_kwargs={
  46. 'leaf_modules': list(_leaf_modules),
  47. 'autowrap_functions': list(_autowrap_functions)
  48. }
  49. )
  50. def create_feature_extractor(model: nn.Module, return_nodes: Union[Dict[str, str], List[str]]):
  51. assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
  52. return _create_feature_extractor(
  53. model, return_nodes,
  54. tracer_kwargs={
  55. 'leaf_modules': list(_leaf_modules),
  56. 'autowrap_functions': list(_autowrap_functions)
  57. }
  58. )