| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215 |
- # mypy: allow-untyped-defs
- import copy
- import torch.nn as nn
- # for backward compatibility
- from torch.ao.quantization.fuser_method_mappings import ( # noqa: F401 # noqa: F401
- fuse_conv_bn,
- fuse_conv_bn_relu,
- get_fuser_method,
- )
- from torch.nn.utils.parametrize import type_before_parametrizations
# Public API of this module. fuse_conv_bn and fuse_conv_bn_relu are imported
# above only for backward compatibility and are not listed here.
__all__ = [
    "fuse_known_modules",
    "fuse_modules",
    "fuse_modules_qat",
]
def _get_module(model, submodule_key):
    """Resolve a dotted attribute path on ``model`` (a generalization of ``getattr``).

    ``_get_module(m, "a.b.c")`` returns ``m.a.b.c``; each segment is looked up
    with ``getattr``, so a missing segment raises ``AttributeError``.
    """
    current = model
    remaining = submodule_key
    while True:
        # Peel off the leftmost path segment; ``dot`` is empty on the last one.
        name, dot, remaining = remaining.partition(".")
        current = getattr(current, name)
        if not dot:
            return current
def _set_module(model, submodule_key, module):
    """Assign ``module`` at a dotted attribute path on ``model`` (a generalization of ``setattr``).

    Every segment except the last is resolved with ``getattr``; the last one is
    the attribute name passed to ``setattr`` on the resulting parent object.
    """
    *parent_path, leaf_name = submodule_key.split(".")
    parent = model
    for name in parent_path:
        parent = getattr(parent, name)
    setattr(parent, leaf_name, module)
def _move_hooks(src, dst, hooks_attr, flag_attrs, register_name):
    """Move every hook stored in ``src.<hooks_attr>`` onto ``dst`` exactly once.

    Args:
        src: module whose hooks are moved; they are removed from it.
        dst: module that receives the hooks through ``dst.<register_name>``.
        hooks_attr: name of the module's ``hook id -> hook`` dict.
        flag_attrs: maps the name of each per-hook flag dict on the module
            (keyed by hook id) to the keyword argument of ``register_name``
            that recreates that flag (e.g. ``with_kwargs``).
        register_name: name of the hook registration method on ``dst``.
    """
    # Snapshot first: ``dst`` may be ``src`` itself, and we mutate the dicts below.
    hooks = list(getattr(src, hooks_attr).items())
    if not hooks:
        return
    hook_ids = [hook_id for hook_id, _ in hooks]
    # Record each hook's registration options so it keeps its calling
    # convention (e.g. a ``with_kwargs`` hook expects an extra argument).
    flags = {
        hook_id: {
            kwarg: True
            for attr, kwarg in flag_attrs.items()
            if hook_id in getattr(src, attr, {})
        }
        for hook_id in hook_ids
    }
    # Drop the original entries from ``src`` and from anything in the fused
    # subtree that still carries them (a fuser that deep-copies ``src`` may
    # copy its hook tables too), so each hook ends up registered only once.
    for mod in (src, *dst.modules()):
        for attr in (hooks_attr, *flag_attrs):
            table = getattr(mod, attr, None)
            if table is None:
                continue
            for hook_id in hook_ids:
                table.pop(hook_id, None)
    register = getattr(dst, register_name)
    for hook_id, hook_fn in hooks:
        # Only pass flags that were set, so plain hooks use the default call.
        register(hook_fn, **flags[hook_id])


def fuse_known_modules(mod_list, is_qat, additional_fuser_method_mapping=None):
    r"""Return a list of known fuse modules.

    Returns a list of modules that fuses the operations specified
    in the input module list.

    The fuser is chosen by ``get_fuser_method`` from the (pre-parametrization)
    types of the modules in ``mod_list`` and ``additional_fuser_method_mapping``.
    Sequences documented as fusible include:

    conv, bn
    conv, bn, relu
    conv, relu
    linear, bn
    linear, relu
    bn, relu

    For these sequences, the first element in the output module list performs
    the fused operation. The rest of the elements are set to nn.Identity().

    Forward pre-hooks of the first module and forward hooks of the last module
    are moved onto the fused module, keeping their ``with_kwargs`` /
    ``always_call`` registration options. No other hooks are transferred by
    this function and they may be lost after the fusion.

    Args:
        mod_list: the modules to fuse, in execution order.
        is_qat: forwarded to the selected fuser method.
        additional_fuser_method_mapping: extra ``{types tuple: fuser}`` entries
            passed to ``get_fuser_method``.

    Returns:
        A list with the same length as ``mod_list``: the fused module followed
        by ``nn.Identity`` modules whose ``training`` flag matches the first
        input module.

    Raises:
        NotImplementedError: if ``get_fuser_method`` returns ``None``.
    """
    module_types = tuple(type_before_parametrizations(m) for m in mod_list)
    fuser_method = get_fuser_method(module_types, additional_fuser_method_mapping)
    if fuser_method is None:
        raise NotImplementedError(f"Cannot fuse modules: {module_types}")
    first, last = mod_list[0], mod_list[-1]
    fused = fuser_method(is_qat, *mod_list)

    _move_hooks(
        first,
        fused,
        "_forward_pre_hooks",
        {"_forward_pre_hooks_with_kwargs": "with_kwargs"},
        "register_forward_pre_hook",
    )
    _move_hooks(
        last,
        fused,
        "_forward_hooks",
        {
            "_forward_hooks_with_kwargs": "with_kwargs",
            "_forward_hooks_always_called": "always_call",
        },
        "register_forward_hook",
    )

    # Remaining slots become identities in the same train/eval mode as the
    # first module.
    identities = []
    for _ in mod_list[1:]:
        identity = nn.Identity()
        identity.training = first.training
        identities.append(identity)
    return [fused, *identities]
def _fuse_modules_helper(
    model,
    modules_to_fuse,
    is_qat,
    fuser_func=fuse_known_modules,
    fuse_custom_config_dict=None,
):
    """Fuse a single group of submodules of ``model`` in place.

    Args:
        model: module owning the submodules; it is mutated.
        modules_to_fuse: dotted submodule names forming one fusion group.
        is_qat: forwarded to ``fuser_func``.
        fuser_func: called as ``fuser_func(modules, is_qat, mapping)``; must
            return one replacement per name in ``modules_to_fuse``.
        fuse_custom_config_dict: optional dict; only its
            ``"additional_fuser_method_mapping"`` entry is read here.
    """
    config = {} if fuse_custom_config_dict is None else fuse_custom_config_dict
    extra_mapping = config.get("additional_fuser_method_mapping", {})
    originals = [_get_module(model, name) for name in modules_to_fuse]
    replacements = fuser_func(originals, is_qat, extra_mapping)
    # Write each replacement back under the name it was fetched from.
    for position, name in enumerate(modules_to_fuse):
        _set_module(model, name, replacements[position])
def _fuse_modules(
    model,
    modules_to_fuse,
    is_qat,
    inplace=False,
    fuser_func=fuse_known_modules,
    fuse_custom_config_dict=None,
):
    """Shared implementation of ``fuse_modules`` and ``fuse_modules_qat``.

    Args:
        model: module whose submodules are fused.
        modules_to_fuse: either one group (a flat list of dotted submodule
            names) or a list of such groups. An empty list fuses nothing.
        is_qat: forwarded to the fuser.
        inplace: when False (default) ``model`` is deep-copied first and the
            copy is modified and returned.
        fuser_func: fuser applied to each group.
        fuse_custom_config_dict: optional fusion configuration.

    Returns:
        The fused model (``model`` itself when ``inplace=True``).
    """
    if not inplace:
        model = copy.deepcopy(model)
    if not modules_to_fuse:
        # Nothing to fuse. Without this guard an empty list passes the
        # "all strings" check below and is handed to the fuser as an empty
        # group, which the default fuser rejects.
        return model
    if all(isinstance(module_element, str) for module_element in modules_to_fuse):
        # A single group given as a flat list of names.
        groups = [modules_to_fuse]
    else:
        # A list of groups.
        groups = modules_to_fuse
    for module_list in groups:
        _fuse_modules_helper(
            model, module_list, is_qat, fuser_func, fuse_custom_config_dict
        )
    return model
def fuse_modules(
    model,
    modules_to_fuse,
    inplace=False,
    fuser_func=fuse_known_modules,
    fuse_custom_config_dict=None,
):
    r"""Fuse a list of modules into a single module.

    Fuses only the following sequence of modules:
    conv, bn
    conv, bn, relu
    conv, relu
    linear, bn
    linear, relu
    bn, relu
    Other sequences are not fused: with the default ``fuser_func`` they raise
    an error rather than being left unchanged.

    For these sequences, replaces the first item in the list
    with the fused module, replacing the rest of the modules
    with identity.

    Args:
        model: Model containing the modules to be fused
        modules_to_fuse: list of list of module names to fuse. Can also be a list
                         of strings if there is only a single list of modules to fuse.
        inplace: bool specifying if fusion happens in place on the model, by default
                 a new model is returned
        fuser_func: Function called as ``fuser_func(mod_list, is_qat,
                    additional_fuser_method_mapping)`` that outputs a list of
                    fused modules of the same length as ``mod_list``. For example,
                    fuser_func([convModule, BNModule], False, {}) returns the list
                    [ConvBNModule, nn.Identity()].
                    Defaults to torch.ao.quantization.fuse_known_modules
        fuse_custom_config_dict: custom configuration for fusion

    .. code-block:: python

       # Example of fuse_custom_config_dict
       fuse_custom_config_dict = {
           # Additional fuser_method mapping
           "additional_fuser_method_mapping": {
               (torch.nn.Conv2d, torch.nn.BatchNorm2d): fuse_conv_bn
           },
       }

    Returns:
        model with fused modules. A new copy is created if inplace=False.

    Examples::

            >>> # xdoctest: +SKIP
            >>> m = M().eval()
            >>> # m is a module containing the sub-modules below
            >>> modules_to_fuse = [ ['conv1', 'bn1', 'relu1'], ['submodule.conv', 'submodule.relu']]
            >>> fused_m = torch.ao.quantization.fuse_modules(m, modules_to_fuse)
            >>> output = fused_m(input)

            >>> m = M().eval()
            >>> # Alternately provide a single list of modules to fuse
            >>> modules_to_fuse = ['conv1', 'bn1', 'relu1']
            >>> fused_m = torch.ao.quantization.fuse_modules(m, modules_to_fuse)
            >>> output = fused_m(input)

    """
    return _fuse_modules(
        model,
        modules_to_fuse,
        is_qat=False,
        inplace=inplace,
        fuser_func=fuser_func,
        fuse_custom_config_dict=fuse_custom_config_dict,
    )
def fuse_modules_qat(
    model,
    modules_to_fuse,
    inplace=False,
    fuser_func=fuse_known_modules,
    fuse_custom_config_dict=None,
):
    """QAT version of `fuse_modules`.

    Accepts the same arguments as ``fuse_modules`` and returns the fused model
    (a deep copy unless ``inplace=True``). The only difference is that the
    fusion runs with ``is_qat=True``.
    """
    options = dict(
        inplace=inplace,
        fuser_func=fuser_func,
        fuse_custom_config_dict=fuse_custom_config_dict,
    )
    return _fuse_modules(model, modules_to_fuse, is_qat=True, **options)
|