_features.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483
  1. """ PyTorch Feature Extraction Helpers
  2. A collection of classes, functions, modules to help extract features from models
  3. and provide a common interface for describing them.
  4. The return_layers, module re-writing idea inspired by torchvision IntermediateLayerGetter
  5. https://github.com/pytorch/vision/blob/d88d8961ae51507d0cb680329d985b1488b1b76b/torchvision/models/_utils.py
  6. Hacked together by / Copyright 2020 Ross Wightman
  7. """
from collections import OrderedDict, defaultdict
from copy import deepcopy
from functools import partial
from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn

from timm.layers import Format, _assert

from ._manipulate import checkpoint
# Names exported by a star-import of this module.
__all__ = [
    'FeatureInfo', 'FeatureHooks', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet', 'FeatureGetterNet',
    'feature_take_indices'
]
def feature_take_indices(
        num_features: int,
        indices: Optional[Union[int, List[int]]] = None,
        as_set: bool = False,
) -> Tuple[List[int], int]:
    """ Determine the absolute feature indices to 'take' from.

    Note: This function can be called in forward() so must be torchscript compatible,
    which requires some incomplete typing and workaround hacks.

    Args:
        num_features: total number of features to select from
        indices: indices to select,
          None -> select all
          int -> select last n
          list/tuple of int -> return specified (-ve indices specify from end)
        as_set: return the indices as a set instead of a list; ignored while scripting

    Returns:
        List (or set) of absolute (from beginning) indices, Maximum index

    Note:
        An empty `indices` sequence is not handled: `max()` is called on an empty list
        and fails. Range checks go through `_assert` (imported from timm.layers;
        presumably raises when the condition is false -- confirm against that helper).
    """
    if indices is None:
        indices = num_features  # all features if None

    if isinstance(indices, int):
        # convert int -> last n indices
        _assert(0 < indices <= num_features, f'last-n ({indices}) is out of range (1 to {num_features})')
        take_indices = [num_features - indices + i for i in range(indices)]
    else:
        take_indices: List[int] = []
        for i in indices:
            # negative entries count back from the end, like Python sequence indexing
            idx = num_features + i if i < 0 else i
            _assert(0 <= idx < num_features, f'feature index {idx} is out of range (0 to {num_features - 1})')
            take_indices.append(idx)

    # the set return is only taken outside of scripting; the declared return type is a list
    if not torch.jit.is_scripting() and as_set:
        return set(take_indices), max(take_indices)

    return take_indices, max(take_indices)
  53. def _out_indices_as_tuple(x: Union[int, Tuple[int, ...]]) -> Tuple[int, ...]:
  54. if isinstance(x, int):
  55. # if indices is an int, take last N features
  56. return tuple(range(-x, 0))
  57. return tuple(x)
# Type accepted for `out_indices` arguments in this module: a single int or a tuple of ints.
OutIndicesT = Union[int, Tuple[int, ...]]
  59. class FeatureInfo:
  60. def __init__(
  61. self,
  62. feature_info: List[Dict],
  63. out_indices: OutIndicesT,
  64. ):
  65. out_indices = _out_indices_as_tuple(out_indices)
  66. prev_reduction = 1
  67. for i, fi in enumerate(feature_info):
  68. # sanity check the mandatory fields, there may be additional fields depending on the model
  69. assert 'num_chs' in fi and fi['num_chs'] > 0
  70. assert 'reduction' in fi and fi['reduction'] >= prev_reduction
  71. prev_reduction = fi['reduction']
  72. assert 'module' in fi
  73. fi.setdefault('index', i)
  74. self.out_indices = out_indices
  75. self.info = feature_info
  76. def from_other(self, out_indices: OutIndicesT):
  77. out_indices = _out_indices_as_tuple(out_indices)
  78. return FeatureInfo(deepcopy(self.info), out_indices)
  79. def get(self, key: str, idx: Optional[Union[int, List[int]]] = None):
  80. """ Get value by key at specified index (indices)
  81. if idx == None, returns value for key at each output index
  82. if idx is an integer, return value for that feature module index (ignoring output indices)
  83. if idx is a list/tuple, return value for each module index (ignoring output indices)
  84. """
  85. if idx is None:
  86. return [self.info[i][key] for i in self.out_indices]
  87. if isinstance(idx, (tuple, list)):
  88. return [self.info[i][key] for i in idx]
  89. else:
  90. return self.info[idx][key]
  91. def get_dicts(self, keys: Optional[List[str]] = None, idx: Optional[Union[int, List[int]]] = None):
  92. """ return info dicts for specified keys (or all if None) at specified indices (or out_indices if None)
  93. """
  94. if idx is None:
  95. if keys is None:
  96. return [self.info[i] for i in self.out_indices]
  97. else:
  98. return [{k: self.info[i][k] for k in keys} for i in self.out_indices]
  99. if isinstance(idx, (tuple, list)):
  100. return [self.info[i] if keys is None else {k: self.info[i][k] for k in keys} for i in idx]
  101. else:
  102. return self.info[idx] if keys is None else {k: self.info[idx][k] for k in keys}
  103. def channels(self, idx: Optional[Union[int, List[int]]] = None):
  104. """ feature channels accessor
  105. """
  106. return self.get('num_chs', idx)
  107. def reduction(self, idx: Optional[Union[int, List[int]]] = None):
  108. """ feature reduction (output stride) accessor
  109. """
  110. return self.get('reduction', idx)
  111. def module_name(self, idx: Optional[Union[int, List[int]]] = None):
  112. """ feature module name accessor
  113. """
  114. return self.get('module', idx)
  115. def __getitem__(self, item):
  116. return self.info[item]
  117. def __len__(self):
  118. return len(self.info)
  119. class FeatureHooks:
  120. """ Feature Hook Helper
  121. This module helps with the setup and extraction of hooks for extracting features from
  122. internal nodes in a model by node name.
  123. FIXME This works well in eager Python but needs redesign for torchscript.
  124. """
  125. def __init__(
  126. self,
  127. hooks: Sequence[Union[str, Dict]],
  128. named_modules: dict,
  129. out_map: Sequence[Union[int, str]] = None,
  130. default_hook_type: str = 'forward',
  131. ):
  132. # setup feature hooks
  133. self._feature_outputs = defaultdict(OrderedDict)
  134. self._handles = []
  135. modules = {k: v for k, v in named_modules}
  136. for i, h in enumerate(hooks):
  137. hook_name = h if isinstance(h, str) else h['module']
  138. m = modules[hook_name]
  139. hook_id = out_map[i] if out_map else hook_name
  140. hook_fn = partial(self._collect_output_hook, hook_id)
  141. hook_type = default_hook_type
  142. if isinstance(h, dict):
  143. hook_type = h.get('hook_type', default_hook_type)
  144. if hook_type == 'forward_pre':
  145. handle = m.register_forward_pre_hook(hook_fn)
  146. elif hook_type == 'forward':
  147. handle = m.register_forward_hook(hook_fn)
  148. else:
  149. assert False, "Unsupported hook type"
  150. self._handles.append(handle)
  151. def _collect_output_hook(self, hook_id, *args):
  152. x = args[-1] # tensor we want is last argument, output for fwd, input for fwd_pre
  153. if isinstance(x, tuple):
  154. x = x[0] # unwrap input tuple
  155. self._feature_outputs[x.device][hook_id] = x
  156. def get_output(self, device) -> Dict[str, torch.tensor]:
  157. output = self._feature_outputs[device]
  158. self._feature_outputs[device] = OrderedDict() # clear after reading
  159. return output
  160. def _module_list(module, flatten_sequential=False):
  161. # a yield/iter would be better for this but wouldn't be compatible with torchscript
  162. ml = []
  163. for name, module in module.named_children():
  164. if flatten_sequential and isinstance(module, nn.Sequential):
  165. # first level of Sequential containers is flattened into containing model
  166. for child_name, child_module in module.named_children():
  167. combined = [name, child_name]
  168. ml.append(('_'.join(combined), '.'.join(combined), child_module))
  169. else:
  170. ml.append((name, name, module))
  171. return ml
  172. def _get_feature_info(net, out_indices: OutIndicesT):
  173. feature_info = getattr(net, 'feature_info')
  174. if isinstance(feature_info, FeatureInfo):
  175. return feature_info.from_other(out_indices)
  176. elif isinstance(feature_info, (list, tuple)):
  177. return FeatureInfo(net.feature_info, out_indices)
  178. else:
  179. assert False, "Provided feature_info is not valid"
  180. def _get_return_layers(feature_info, out_map):
  181. module_names = feature_info.module_name()
  182. return_layers = {}
  183. for i, name in enumerate(module_names):
  184. return_layers[name] = out_map[i] if out_map is not None else feature_info.out_indices[i]
  185. return return_layers
class FeatureDictNet(nn.ModuleDict):
    """ Feature extractor with OrderedDict return

    Wrap a model and extract features as specified by the out indices, the network is
    partially re-built from contained modules.

    There is a strong assumption that the modules have been registered into the model in the same
    order as they are used. There should be no reuse of the same nn.Module more than once, including
    trivial modules like `self.relu = nn.ReLU`.

    Only submodules that are directly assigned to the model class (`model.feature1`) or at most
    one Sequential container deep (`model.features.1`, with flatten_sequential=True) can be captured.
    All Sequential containers that are directly assigned to the original model will have their
    modules assigned to this module with the name `model.features.1` being changed to `model.features_1`
    """
    def __init__(
            self,
            model: nn.Module,
            out_indices: OutIndicesT = (0, 1, 2, 3, 4),
            out_map: Optional[Sequence[Union[int, str]]] = None,
            output_fmt: str = 'NCHW',
            feature_concat: bool = False,
            flatten_sequential: bool = False,
    ):
        """
        Args:
            model: Model from which to extract features.
            out_indices: Output indices of the model features to extract.
            out_map: Return id mapping for each output index, otherwise str(index) is used.
            output_fmt: Output format name, converted to a `Format` and stored as
                `self.output_fmt`; it is not otherwise read inside this class.
            feature_concat: Concatenate intermediate features that are lists or tuples instead of selecting
                first element e.g. `x[0]`
            flatten_sequential: Flatten first two-levels of sequential modules in model (re-writes model modules)
        """
        super().__init__()
        self.feature_info = _get_feature_info(model, out_indices)
        self.output_fmt = Format(output_fmt)
        self.concat = feature_concat
        self.grad_checkpointing = False
        # maps (possibly flattened) module name in this ModuleDict -> str return id
        self.return_layers = {}

        return_layers = _get_return_layers(self.feature_info, out_map)
        modules = _module_list(model, flatten_sequential=flatten_sequential)
        remaining = set(return_layers.keys())
        layers = OrderedDict()
        for new_name, old_name, module in modules:
            layers[new_name] = module
            if old_name in remaining:
                # return id has to be consistently str type for torchscript
                self.return_layers[new_name] = str(return_layers[old_name])
                remaining.remove(old_name)
            # modules after the last requested feature are not copied into this net
            if not remaining:
                break
        assert not remaining and len(self.return_layers) == len(return_layers), \
            f'Return layers ({remaining}) are not present in model'
        self.update(layers)

    def set_grad_checkpointing(self, enable: bool = True):
        """ Enable or disable gradient checkpointing in `_collect`. """
        self.grad_checkpointing = enable

    def _collect(self, x) -> (Dict[str, torch.Tensor]):
        """ Run the contained modules in order and gather the outputs of the return layers.

        Returns an OrderedDict of return id -> feature, in the order the layers are executed.
        """
        out = OrderedDict()
        for i, (name, module) in enumerate(self.items()):
            if self.grad_checkpointing and not torch.jit.is_scripting():
                # Skipping checkpoint of first module because need a gradient at input
                # Skipping last because networks with in-place ops might fail w/ checkpointing enabled
                # NOTE: first_or_last module could be static, but recalc in is_scripting guard to avoid jit issues
                first_or_last_module = i == 0 or i == max(len(self) - 1, 0)
                x = module(x) if first_or_last_module else checkpoint(module, x)
            else:
                x = module(x)

            if name in self.return_layers:
                out_id = self.return_layers[name]
                if isinstance(x, (tuple, list)):
                    # If model tap is a tuple or list, concat or select first element
                    # FIXME this may need to be more generic / flexible for some nets
                    out[out_id] = torch.cat(x, 1) if self.concat else x[0]
                else:
                    out[out_id] = x
        return out

    def forward(self, x) -> Dict[str, torch.Tensor]:
        """ Return the collected features as a dict keyed by return id. """
        return self._collect(x)
  261. class FeatureListNet(FeatureDictNet):
  262. """ Feature extractor with list return
  263. A specialization of FeatureDictNet that always returns features as a list (values() of dict).
  264. """
  265. def __init__(
  266. self,
  267. model: nn.Module,
  268. out_indices: OutIndicesT = (0, 1, 2, 3, 4),
  269. output_fmt: str = 'NCHW',
  270. feature_concat: bool = False,
  271. flatten_sequential: bool = False,
  272. ):
  273. """
  274. Args:
  275. model: Model from which to extract features.
  276. out_indices: Output indices of the model features to extract.
  277. feature_concat: Concatenate intermediate features that are lists or tuples instead of selecting
  278. first element e.g. `x[0]`
  279. flatten_sequential: Flatten first two-levels of sequential modules in model (re-writes model modules)
  280. """
  281. super().__init__(
  282. model,
  283. out_indices=out_indices,
  284. output_fmt=output_fmt,
  285. feature_concat=feature_concat,
  286. flatten_sequential=flatten_sequential,
  287. )
  288. def forward(self, x) -> (List[torch.Tensor]):
  289. return list(self._collect(x).values())
  290. class FeatureHookNet(nn.ModuleDict):
  291. """ FeatureHookNet
  292. Wrap a model and extract features specified by the out indices using forward/forward-pre hooks.
  293. If `no_rewrite` is True, features are extracted via hooks without modifying the underlying
  294. network in any way.
  295. If `no_rewrite` is False, the model will be re-written as in the
  296. FeatureList/FeatureDict case by folding first to second (Sequential only) level modules into this one.
  297. FIXME this does not currently work with Torchscript, see FeatureHooks class
  298. """
  299. def __init__(
  300. self,
  301. model: nn.Module,
  302. out_indices: OutIndicesT = (0, 1, 2, 3, 4),
  303. out_map: Optional[Sequence[Union[int, str]]] = None,
  304. return_dict: bool = False,
  305. output_fmt: str = 'NCHW',
  306. no_rewrite: Optional[bool] = None,
  307. flatten_sequential: bool = False,
  308. default_hook_type: str = 'forward',
  309. ):
  310. """
  311. Args:
  312. model: Model from which to extract features.
  313. out_indices: Output indices of the model features to extract.
  314. out_map: Return id mapping for each output index, otherwise str(index) is used.
  315. return_dict: Output features as a dict.
  316. no_rewrite: Enforce that model is not re-written if True, ie no modules are removed / changed.
  317. flatten_sequential arg must also be False if this is set True.
  318. flatten_sequential: Re-write modules by flattening first two levels of nn.Sequential containers.
  319. default_hook_type: The default hook type to use if not specified in model.feature_info.
  320. """
  321. super().__init__()
  322. assert not torch.jit.is_scripting()
  323. self.feature_info = _get_feature_info(model, out_indices)
  324. self.return_dict = return_dict
  325. self.output_fmt = Format(output_fmt)
  326. self.grad_checkpointing = False
  327. if no_rewrite is None:
  328. no_rewrite = not flatten_sequential
  329. layers = OrderedDict()
  330. hooks = []
  331. if no_rewrite:
  332. assert not flatten_sequential
  333. if hasattr(model, 'reset_classifier'): # make sure classifier is removed?
  334. model.reset_classifier(0)
  335. layers['body'] = model
  336. hooks.extend(self.feature_info.get_dicts())
  337. else:
  338. modules = _module_list(model, flatten_sequential=flatten_sequential)
  339. remaining = {
  340. f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type
  341. for f in self.feature_info.get_dicts()
  342. }
  343. for new_name, old_name, module in modules:
  344. layers[new_name] = module
  345. for fn, fm in module.named_modules(prefix=old_name):
  346. if fn in remaining:
  347. hooks.append(dict(module=fn, hook_type=remaining[fn]))
  348. del remaining[fn]
  349. if not remaining:
  350. break
  351. assert not remaining, f'Return layers ({remaining}) are not present in model'
  352. self.update(layers)
  353. self.hooks = FeatureHooks(hooks, model.named_modules(), out_map=out_map)
  354. def set_grad_checkpointing(self, enable: bool = True):
  355. self.grad_checkpointing = enable
  356. def forward(self, x):
  357. for i, (name, module) in enumerate(self.items()):
  358. if self.grad_checkpointing and not torch.jit.is_scripting():
  359. # Skipping checkpoint of first module because need a gradient at input
  360. # Skipping last because networks with in-place ops might fail w/ checkpointing enabled
  361. # NOTE: first_or_last module could be static, but recalc in is_scripting guard to avoid jit issues
  362. first_or_last_module = i == 0 or i == max(len(self) - 1, 0)
  363. x = module(x) if first_or_last_module else checkpoint(module, x)
  364. else:
  365. x = module(x)
  366. out = self.hooks.get_output(x.device)
  367. return out if self.return_dict else list(out.values())
  368. class FeatureGetterNet(nn.ModuleDict):
  369. """ FeatureGetterNet
  370. Wrap models with a feature getter method, like 'get_intermediate_layers'
  371. """
  372. def __init__(
  373. self,
  374. model: nn.Module,
  375. out_indices: OutIndicesT = 4,
  376. out_map: Optional[Sequence[Union[int, str]]] = None,
  377. return_dict: bool = False,
  378. output_fmt: str = 'NCHW',
  379. norm: bool = False,
  380. prune: bool = True,
  381. ):
  382. """
  383. Args:
  384. model: Model to wrap.
  385. out_indices: Indices of features to extract.
  386. out_map: Remap feature names for dict output (WIP, not supported).
  387. return_dict: Return features as dictionary instead of list (WIP, not supported).
  388. norm: Apply final model norm to all output features (if possible).
  389. """
  390. super().__init__()
  391. if prune and hasattr(model, 'prune_intermediate_layers'):
  392. # replace out_indices after they've been normalized, -ve indices will be invalid after prune
  393. out_indices = model.prune_intermediate_layers(
  394. out_indices,
  395. prune_norm=not norm,
  396. )
  397. self.feature_info = _get_feature_info(model, out_indices)
  398. self.model = model
  399. self.out_indices = out_indices
  400. self.out_map = out_map
  401. self.return_dict = return_dict
  402. self.output_fmt = Format(output_fmt)
  403. self.norm = norm
  404. def forward(self, x):
  405. features = self.model.forward_intermediates(
  406. x,
  407. indices=self.out_indices,
  408. norm=self.norm,
  409. output_fmt=self.output_fmt,
  410. intermediates_only=True,
  411. )
  412. return features