_features_fx.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. """ PyTorch FX Based Feature Extraction Helpers
  2. Using https://pytorch.org/vision/stable/feature_extraction.html
  3. """
  4. from typing import Callable, Dict, List, Optional, Union, Tuple, Type
  5. import torch
  6. from torch import nn
  7. from timm.layers import (
  8. create_feature_extractor,
  9. get_graph_node_names,
  10. register_notrace_module,
  11. register_notrace_function,
  12. is_notrace_module,
  13. is_notrace_function,
  14. get_notrace_functions,
  15. get_notrace_modules,
  16. Format,
  17. )
  18. from ._features import _get_feature_info, _get_return_layers
# Public API of this module: the notrace registration / query helpers and the FX
# extraction functions imported from timm.layers, plus the two wrapper modules
# (FeatureGraphNet, GraphExtractNet) defined in this file.
__all__ = [
    'register_notrace_module',
    'is_notrace_module',
    'get_notrace_modules',
    'register_notrace_function',
    'is_notrace_function',
    'get_notrace_functions',
    'create_feature_extractor',
    'get_graph_node_names',
    'FeatureGraphNet',
    'GraphExtractNet',
]
  31. class FeatureGraphNet(nn.Module):
  32. """ A FX Graph based feature extractor that works with the model feature_info metadata
  33. """
  34. return_dict: torch.jit.Final[bool]
  35. def __init__(
  36. self,
  37. model: nn.Module,
  38. out_indices: Tuple[int, ...],
  39. out_map: Optional[Dict] = None,
  40. output_fmt: str = 'NCHW',
  41. return_dict: bool = False,
  42. ):
  43. super().__init__()
  44. self.feature_info = _get_feature_info(model, out_indices)
  45. if out_map is not None:
  46. assert len(out_map) == len(out_indices)
  47. self.output_fmt = Format(output_fmt)
  48. return_nodes = _get_return_layers(self.feature_info, out_map)
  49. self.graph_module = create_feature_extractor(model, return_nodes)
  50. self.return_dict = return_dict
  51. def forward(self, x):
  52. out = self.graph_module(x)
  53. if self.return_dict:
  54. return out
  55. return list(out.values())
  56. class GraphExtractNet(nn.Module):
  57. """ A standalone feature extraction wrapper that maps dict -> list or single tensor
  58. NOTE:
  59. * one can use feature_extractor directly if dictionary output is desired
  60. * unlike FeatureGraphNet, this is intended to be used standalone and not with model feature_info
  61. metadata for builtin feature extraction mode
  62. * create_feature_extractor can be used directly if dictionary output is acceptable
  63. Args:
  64. model: model to extract features from
  65. return_nodes: node names to return features from (dict or list)
  66. squeeze_out: if only one output, and output in list format, flatten to single tensor
  67. return_dict: return as dictionary from extractor with node names as keys, ignores squeeze_out arg
  68. """
  69. return_dict: torch.jit.Final[bool]
  70. def __init__(
  71. self,
  72. model: nn.Module,
  73. return_nodes: Union[Dict[str, str], List[str]],
  74. squeeze_out: bool = True,
  75. return_dict: bool = False,
  76. ):
  77. super().__init__()
  78. self.squeeze_out = squeeze_out
  79. self.graph_module = create_feature_extractor(model, return_nodes)
  80. self.return_dict = return_dict
  81. def forward(self, x) -> Union[List[torch.Tensor], torch.Tensor]:
  82. out = self.graph_module(x)
  83. if self.return_dict:
  84. return out
  85. out = list(out.values())
  86. return out[0] if self.squeeze_out and len(out) == 1 else out