_equalize.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279
  1. # mypy: allow-untyped-defs
  2. import copy
  3. from itertools import chain
  4. from typing import Any
  5. import torch
# Names exported by a star-import of this module. Every entry is a function
# defined in this file.
__all__ = [
    "set_module_weight",
    "set_module_bias",
    "has_bias",
    "get_module_weight",
    "get_module_bias",
    "max_over_ndim",
    "min_over_ndim",
    "channel_range",
    "get_name_by_module",
    "cross_layer_equalization",
    "process_paired_modules_list_to_name",
    "expand_groups_in_paired_modules_list",
    "equalize",
    "converged",
]
  22. _supported_types = {torch.nn.Conv2d, torch.nn.Linear, torch.nn.Conv1d}
  23. _supported_intrinsic_types = {
  24. torch.ao.nn.intrinsic.ConvReLU2d,
  25. torch.ao.nn.intrinsic.LinearReLU,
  26. torch.ao.nn.intrinsic.ConvReLU1d,
  27. }
  28. _all_supported_types = _supported_types.union(_supported_intrinsic_types)
  29. def set_module_weight(module, weight) -> None:
  30. if type(module) in _supported_types:
  31. module.weight = torch.nn.Parameter(weight)
  32. else:
  33. module[0].weight = torch.nn.Parameter(weight)
  34. def set_module_bias(module, bias) -> None:
  35. if type(module) in _supported_types:
  36. module.bias = torch.nn.Parameter(bias)
  37. else:
  38. module[0].bias = torch.nn.Parameter(bias)
  39. def has_bias(module) -> bool:
  40. if type(module) in _supported_types:
  41. return module.bias is not None
  42. else:
  43. return module[0].bias is not None
  44. def get_module_weight(module):
  45. if type(module) in _supported_types:
  46. return module.weight
  47. else:
  48. return module[0].weight
  49. def get_module_bias(module):
  50. if type(module) in _supported_types:
  51. return module.bias
  52. else:
  53. return module[0].bias
  54. def max_over_ndim(input, axis_list, keepdim=False):
  55. """Apply 'torch.max' over the given axes."""
  56. axis_list.sort(reverse=True)
  57. for axis in axis_list:
  58. input, _ = input.max(axis, keepdim)
  59. return input
  60. def min_over_ndim(input, axis_list, keepdim=False):
  61. """Apply 'torch.min' over the given axes."""
  62. axis_list.sort(reverse=True)
  63. for axis in axis_list:
  64. input, _ = input.min(axis, keepdim)
  65. return input
  66. def channel_range(input, axis=0):
  67. """Find the range of weights associated with a specific channel."""
  68. size_of_tensor_dim = input.ndim
  69. axis_list = list(range(size_of_tensor_dim))
  70. axis_list.remove(axis)
  71. mins = min_over_ndim(input, axis_list)
  72. maxs = max_over_ndim(input, axis_list)
  73. if mins.size(0) != input.size(axis):
  74. raise AssertionError(
  75. "Dimensions of resultant channel range does not match size of requested axis"
  76. )
  77. return maxs - mins
  78. def get_name_by_module(model, module):
  79. """Get the name of a module within a model.
  80. Args:
  81. model: a model (nn.module) that equalization is to be applied on
  82. module: a module within the model
  83. Returns:
  84. name: the name of the module within the model
  85. """
  86. for name, m in model.named_modules():
  87. if m is module:
  88. return name
  89. raise ValueError("module is not in the model")
  90. def cross_layer_equalization(module1, module2, output_axis=0, input_axis=1):
  91. """Scale the range of Tensor1.output to equal Tensor2.input.
  92. Given two adjacent tensors', the weights are scaled such that
  93. the ranges of the first tensors' output channel are equal to the
  94. ranges of the second tensors' input channel
  95. """
  96. if (
  97. type(module1) not in _all_supported_types
  98. or type(module2) not in _all_supported_types
  99. ):
  100. raise ValueError(
  101. "module type not supported:", type(module1), " ", type(module2)
  102. )
  103. bias = get_module_bias(module1) if has_bias(module1) else None
  104. weight1 = get_module_weight(module1)
  105. weight2 = get_module_weight(module2)
  106. if weight1.size(output_axis) != weight2.size(input_axis):
  107. raise TypeError(
  108. "Number of output channels of first arg do not match \
  109. number input channels of second arg"
  110. )
  111. weight1_range = channel_range(weight1, output_axis)
  112. weight2_range = channel_range(weight2, input_axis)
  113. # producing scaling factors to applied
  114. weight2_range += 1e-9
  115. scaling_factors = torch.sqrt(weight1_range / weight2_range)
  116. inverse_scaling_factors = torch.reciprocal(scaling_factors)
  117. if bias is not None:
  118. bias = bias * inverse_scaling_factors
  119. # formatting the scaling (1D) tensors to be applied on the given argument tensors
  120. # pads axis to (1D) tensors to then be broadcasted
  121. size1 = [1] * weight1.ndim
  122. size1[output_axis] = weight1.size(output_axis)
  123. size2 = [1] * weight2.ndim
  124. size2[input_axis] = weight2.size(input_axis)
  125. scaling_factors = torch.reshape(scaling_factors, size2)
  126. inverse_scaling_factors = torch.reshape(inverse_scaling_factors, size1)
  127. weight1 = weight1 * inverse_scaling_factors
  128. weight2 = weight2 * scaling_factors
  129. set_module_weight(module1, weight1)
  130. if bias is not None:
  131. set_module_bias(module1, bias)
  132. set_module_weight(module2, weight2)
  133. def process_paired_modules_list_to_name(model, paired_modules_list):
  134. """Processes a list of paired modules to a list of names of paired modules."""
  135. for group in paired_modules_list:
  136. for i, item in enumerate(group):
  137. if isinstance(item, torch.nn.Module):
  138. group[i] = get_name_by_module(model, item)
  139. elif not isinstance(item, str):
  140. raise TypeError("item must be a nn.Module or a string")
  141. return paired_modules_list
  142. def expand_groups_in_paired_modules_list(paired_modules_list):
  143. """Expands module pair groups larger than two into groups of two modules."""
  144. new_list = []
  145. for group in paired_modules_list:
  146. if len(group) == 1:
  147. raise ValueError("Group must have at least two modules")
  148. elif len(group) == 2:
  149. new_list.append(group)
  150. elif len(group) > 2:
  151. new_list.extend([group[i], group[i + 1]] for i in range(len(group) - 1))
  152. return new_list
  153. def equalize(model, paired_modules_list, threshold=1e-4, inplace=True):
  154. """Equalize modules until convergence is achieved.
  155. Given a list of adjacent modules within a model, equalization will
  156. be applied between each pair, this will repeated until convergence is achieved
  157. Keeps a copy of the changing modules from the previous iteration, if the copies
  158. are not that different than the current modules (determined by converged_test),
  159. then the modules have converged enough that further equalizing is not necessary
  160. Reference is section 4.1 of this paper https://arxiv.org/pdf/1906.04721.pdf
  161. Args:
  162. model: a model (nn.Module) that equalization is to be applied on
  163. paired_modules_list (List(List[nn.module || str])): a list of lists
  164. where each sublist is a pair of two submodules found in the model,
  165. for each pair the two modules have to be adjacent in the model,
  166. with only piece-wise-linear functions like a (P)ReLU or LeakyReLU in between
  167. to get expected results.
  168. The list can contain either modules, or names of modules in the model.
  169. If you pass multiple modules in the same list, they will all be equalized together.
  170. threshold (float): a number used by the converged function to determine what degree
  171. of similarity between models is necessary for them to be called equivalent
  172. inplace (bool): determines if function is inplace or not
  173. """
  174. paired_modules_list = process_paired_modules_list_to_name(
  175. model, paired_modules_list
  176. )
  177. if not inplace:
  178. model = copy.deepcopy(model)
  179. paired_modules_list = expand_groups_in_paired_modules_list(paired_modules_list)
  180. name_to_module: dict[str, torch.nn.Module] = {}
  181. previous_name_to_module: dict[str, Any] = {}
  182. name_set = set(chain.from_iterable(paired_modules_list))
  183. for name, module in model.named_modules():
  184. if name in name_set:
  185. name_to_module[name] = module
  186. previous_name_to_module[name] = None
  187. while not converged(name_to_module, previous_name_to_module, threshold):
  188. for pair in paired_modules_list:
  189. previous_name_to_module[pair[0]] = copy.deepcopy(name_to_module[pair[0]])
  190. previous_name_to_module[pair[1]] = copy.deepcopy(name_to_module[pair[1]])
  191. cross_layer_equalization(name_to_module[pair[0]], name_to_module[pair[1]])
  192. return model
  193. def converged(curr_modules, prev_modules, threshold=1e-4):
  194. """Test whether modules are converged to a specified threshold.
  195. Tests for the summed norm of the differences between each set of modules
  196. being less than the given threshold
  197. Takes two dictionaries mapping names to modules, the set of names for each dictionary
  198. should be the same, looping over the set of names, for each name take the difference
  199. between the associated modules in each dictionary
  200. """
  201. if curr_modules.keys() != prev_modules.keys():
  202. raise ValueError(
  203. "The keys to the given mappings must have the same set of names of modules"
  204. )
  205. summed_norms = torch.tensor(0.0)
  206. if None in prev_modules.values():
  207. return False
  208. for name in curr_modules:
  209. curr_weight = get_module_weight(curr_modules[name])
  210. prev_weight = get_module_weight(prev_modules[name])
  211. difference = curr_weight.sub(prev_weight)
  212. summed_norms += torch.norm(difference)
  213. return bool(summed_norms < threshold)