fuse_modules.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215
  1. # mypy: allow-untyped-defs
  2. import copy
  3. import torch.nn as nn
  4. # for backward compatibility
  5. from torch.ao.quantization.fuser_method_mappings import ( # noqa: F401 # noqa: F401
  6. fuse_conv_bn,
  7. fuse_conv_bn_relu,
  8. get_fuser_method,
  9. )
  10. from torch.nn.utils.parametrize import type_before_parametrizations
# Names exported by ``from <this module> import *``; the underscore-prefixed
# helpers defined in this file are intentionally left out.
__all__ = [
    "fuse_known_modules",
    "fuse_modules",
    "fuse_modules_qat",
]
  16. # Generalization of getattr
  17. def _get_module(model, submodule_key):
  18. tokens = submodule_key.split(".")
  19. cur_mod = model
  20. for s in tokens:
  21. cur_mod = getattr(cur_mod, s)
  22. return cur_mod
  23. # Generalization of setattr
  24. def _set_module(model, submodule_key, module):
  25. tokens = submodule_key.split(".")
  26. sub_tokens = tokens[:-1]
  27. cur_mod = model
  28. for s in sub_tokens:
  29. cur_mod = getattr(cur_mod, s)
  30. setattr(cur_mod, tokens[-1], module)
  31. def fuse_known_modules(mod_list, is_qat, additional_fuser_method_mapping=None):
  32. r"""Return a list of known fuse modules.
  33. Returns a list of modules that fuses the operations specified
  34. in the input module list.
  35. Fuses only the following sequence of modules:
  36. conv, bn
  37. conv, bn, relu
  38. conv, relu
  39. linear, bn
  40. linear, relu
  41. For these sequences, the first element in the output module list performs
  42. the fused operation. The rest of the elements are set to nn.Identity()
  43. """
  44. types = tuple(type_before_parametrizations(m) for m in mod_list)
  45. fuser_method = get_fuser_method(types, additional_fuser_method_mapping)
  46. if fuser_method is None:
  47. raise NotImplementedError(f"Cannot fuse modules: {types}")
  48. new_mod: list[nn.Module | None] = [None] * len(mod_list)
  49. fused = fuser_method(is_qat, *mod_list)
  50. # NOTE: forward hooks not processed in the two following for loops will be lost after the fusion
  51. # Move pre forward hooks of the base module to resulting fused module
  52. for pre_hook_fn in mod_list[0]._forward_pre_hooks.values():
  53. fused.register_forward_pre_hook(pre_hook_fn)
  54. mod_list[0]._forward_pre_hooks.clear()
  55. # Move post forward hooks of the last module to resulting fused module
  56. for hook_fn in mod_list[-1]._forward_hooks.values():
  57. fused.register_forward_hook(hook_fn)
  58. mod_list[-1]._forward_hooks.clear()
  59. new_mod[0] = fused
  60. for i in range(1, len(mod_list)):
  61. identity = nn.Identity()
  62. identity.training = mod_list[0].training
  63. new_mod[i] = identity
  64. return new_mod
  65. def _fuse_modules_helper(
  66. model,
  67. modules_to_fuse,
  68. is_qat,
  69. fuser_func=fuse_known_modules,
  70. fuse_custom_config_dict=None,
  71. ):
  72. if fuse_custom_config_dict is None:
  73. fuse_custom_config_dict = {}
  74. additional_fuser_method_mapping = fuse_custom_config_dict.get(
  75. "additional_fuser_method_mapping", {}
  76. )
  77. mod_list = [_get_module(model, item) for item in modules_to_fuse]
  78. # Fuse list of modules
  79. new_mod_list = fuser_func(mod_list, is_qat, additional_fuser_method_mapping)
  80. # Replace original module list with fused module list
  81. for i, item in enumerate(modules_to_fuse):
  82. _set_module(model, item, new_mod_list[i])
  83. def _fuse_modules(
  84. model,
  85. modules_to_fuse,
  86. is_qat,
  87. inplace=False,
  88. fuser_func=fuse_known_modules,
  89. fuse_custom_config_dict=None,
  90. ):
  91. if not inplace:
  92. model = copy.deepcopy(model)
  93. if all(isinstance(module_element, str) for module_element in modules_to_fuse):
  94. # Handle case of modules_to_fuse being a list
  95. _fuse_modules_helper(
  96. model, modules_to_fuse, is_qat, fuser_func, fuse_custom_config_dict
  97. )
  98. else:
  99. # Handle case of modules_to_fuse being a list of lists
  100. for module_list in modules_to_fuse:
  101. _fuse_modules_helper(
  102. model, module_list, is_qat, fuser_func, fuse_custom_config_dict
  103. )
  104. return model
  105. def fuse_modules(
  106. model,
  107. modules_to_fuse,
  108. inplace=False,
  109. fuser_func=fuse_known_modules,
  110. fuse_custom_config_dict=None,
  111. ):
  112. r"""Fuse a list of modules into a single module.
  113. Fuses only the following sequence of modules:
  114. conv, bn
  115. conv, bn, relu
  116. conv, relu
  117. linear, relu
  118. bn, relu
  119. All other sequences are left unchanged.
  120. For these sequences, replaces the first item in the list
  121. with the fused module, replacing the rest of the modules
  122. with identity.
  123. Args:
  124. model: Model containing the modules to be fused
  125. modules_to_fuse: list of list of module names to fuse. Can also be a list
  126. of strings if there is only a single list of modules to fuse.
  127. inplace: bool specifying if fusion happens in place on the model, by default
  128. a new model is returned
  129. fuser_func: Function that takes in a list of modules and outputs a list of fused modules
  130. of the same length. For example,
  131. fuser_func([convModule, BNModule]) returns the list [ConvBNModule, nn.Identity()]
  132. Defaults to torch.ao.quantization.fuse_known_modules
  133. `fuse_custom_config_dict`: custom configuration for fusion
  134. .. code-block:: python
  135. # Example of fuse_custom_config_dict
  136. fuse_custom_config_dict = {
  137. # Additional fuser_method mapping
  138. "additional_fuser_method_mapping": {
  139. (torch.nn.Conv2d, torch.nn.BatchNorm2d): fuse_conv_bn
  140. },
  141. }
  142. Returns:
  143. model with fused modules. A new copy is created if inplace=True.
  144. Examples::
  145. >>> # xdoctest: +SKIP
  146. >>> m = M().eval()
  147. >>> # m is a module containing the sub-modules below
  148. >>> modules_to_fuse = [ ['conv1', 'bn1', 'relu1'], ['submodule.conv', 'submodule.relu']]
  149. >>> fused_m = torch.ao.quantization.fuse_modules(m, modules_to_fuse)
  150. >>> output = fused_m(input)
  151. >>> m = M().eval()
  152. >>> # Alternately provide a single list of modules to fuse
  153. >>> modules_to_fuse = ['conv1', 'bn1', 'relu1']
  154. >>> fused_m = torch.ao.quantization.fuse_modules(m, modules_to_fuse)
  155. >>> output = fused_m(input)
  156. """
  157. return _fuse_modules(
  158. model,
  159. modules_to_fuse,
  160. is_qat=False,
  161. inplace=inplace,
  162. fuser_func=fuser_func,
  163. fuse_custom_config_dict=fuse_custom_config_dict,
  164. )
  165. def fuse_modules_qat(
  166. model,
  167. modules_to_fuse,
  168. inplace=False,
  169. fuser_func=fuse_known_modules,
  170. fuse_custom_config_dict=None,
  171. ):
  172. """QAT version for `fuse_modules`."""
  173. return _fuse_modules(
  174. model,
  175. modules_to_fuse,
  176. is_qat=True,
  177. inplace=inplace,
  178. fuser_func=fuser_func,
  179. fuse_custom_config_dict=fuse_custom_config_dict,
  180. )