attention_extract.py 3.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. import fnmatch
  2. import re
  3. from collections import OrderedDict
  4. from typing import Union, Optional, List
  5. import torch
  6. class AttentionExtract(torch.nn.Module):
  7. # defaults should cover a significant number of timm models with attention maps.
  8. default_node_names = ['*attn.softmax']
  9. default_module_names = ['*attn_drop']
  10. def __init__(
  11. self,
  12. model: Union[torch.nn.Module],
  13. names: Optional[List[str]] = None,
  14. mode: str = 'eval',
  15. method: str = 'fx',
  16. hook_type: str = 'forward',
  17. use_regex: bool = False,
  18. ):
  19. """ Extract attention maps (or other activations) from a model by name.
  20. Args:
  21. model: Instantiated model to extract from.
  22. names: List of concrete or wildcard names to extract. Names are nodes for fx and modules for hooks.
  23. mode: 'train' or 'eval' model mode.
  24. method: 'fx' or 'hook' extraction method.
  25. hook_type: 'forward' or 'forward_pre' hooks used.
  26. use_regex: Use regex instead of fnmatch
  27. """
  28. super().__init__()
  29. assert mode in ('train', 'eval')
  30. if mode == 'train':
  31. model = model.train()
  32. else:
  33. model = model.eval()
  34. assert method in ('fx', 'hook')
  35. if method == 'fx':
  36. # names are activation node names
  37. from timm.models._features_fx import get_graph_node_names, GraphExtractNet
  38. node_names = get_graph_node_names(model)[0 if mode == 'train' else 1]
  39. names = names or self.default_node_names
  40. if use_regex:
  41. regexes = [re.compile(r) for r in names]
  42. matched = [g for g in node_names if any([r.match(g) for r in regexes])]
  43. else:
  44. matched = [g for g in node_names if any([fnmatch.fnmatch(g, n) for n in names])]
  45. if not matched:
  46. raise RuntimeError(f'No node names found matching {names}.')
  47. self.model = GraphExtractNet(model, matched, return_dict=True)
  48. self.hooks = None
  49. else:
  50. # names are module names
  51. assert hook_type in ('forward', 'forward_pre')
  52. from timm.models._features import FeatureHooks
  53. module_names = [n for n, m in model.named_modules()]
  54. names = names or self.default_module_names
  55. if use_regex:
  56. regexes = [re.compile(r) for r in names]
  57. matched = [m for m in module_names if any([r.match(m) for r in regexes])]
  58. else:
  59. matched = [m for m in module_names if any([fnmatch.fnmatch(m, n) for n in names])]
  60. if not matched:
  61. raise RuntimeError(f'No module names found matching {names}.')
  62. self.model = model
  63. self.hooks = FeatureHooks(matched, model.named_modules(), default_hook_type=hook_type)
  64. self.names = matched
  65. self.mode = mode
  66. self.method = method
  67. def forward(self, x):
  68. if self.hooks is not None:
  69. self.model(x)
  70. output = self.hooks.get_output(device=x.device)
  71. else:
  72. output = self.model(x)
  73. return output