| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081 |
- from typing import Callable, Dict, List, Optional, Union, Tuple, Type
- import torch
- from torch import nn
- try:
- # NOTE we wrap torchvision fns to use timm leaf / no trace definitions
- from torchvision.models.feature_extraction import create_feature_extractor as _create_feature_extractor
- from torchvision.models.feature_extraction import get_graph_node_names as _get_graph_node_names
- has_fx_feature_extraction = True
- except ImportError:
- has_fx_feature_extraction = False
- __all__ = [
- 'register_notrace_module',
- 'is_notrace_module',
- 'get_notrace_modules',
- 'register_notrace_function',
- 'is_notrace_function',
- 'get_notrace_functions',
- 'create_feature_extractor',
- 'get_graph_node_names',
- ]
- # modules to treat as leafs when tracing
- _leaf_modules = set()
- def register_notrace_module(module: Type[nn.Module]):
- """
- Any module not under timm.models.layers should get this decorator if we don't want to trace through it.
- """
- _leaf_modules.add(module)
- return module
- def is_notrace_module(module: Type[nn.Module]):
- return module in _leaf_modules
- def get_notrace_modules():
- return list(_leaf_modules)
- # Functions we want to autowrap (treat them as leaves)
- _autowrap_functions = set()
- def register_notrace_function(name_or_fn):
- _autowrap_functions.add(name_or_fn)
- return name_or_fn
- def is_notrace_function(func: Callable):
- return func in _autowrap_functions
- def get_notrace_functions():
- return list(_autowrap_functions)
- def get_graph_node_names(model: nn.Module) -> Tuple[List[str], List[str]]:
- return _get_graph_node_names(
- model,
- tracer_kwargs={
- 'leaf_modules': list(_leaf_modules),
- 'autowrap_functions': list(_autowrap_functions)
- }
- )
- def create_feature_extractor(model: nn.Module, return_nodes: Union[Dict[str, str], List[str]]):
- assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
- return _create_feature_extractor(
- model, return_nodes,
- tracer_kwargs={
- 'leaf_modules': list(_leaf_modules),
- 'autowrap_functions': list(_autowrap_functions)
- }
- )
|