fusion.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. from __future__ import annotations
  2. import copy
  3. from typing import TypeVar
  4. import torch
  5. __all__ = [
  6. "fuse_conv_bn_eval",
  7. "fuse_conv_bn_weights",
  8. "fuse_linear_bn_eval",
  9. "fuse_linear_bn_weights",
  10. ]
  11. ConvT = TypeVar("ConvT", bound="torch.nn.modules.conv._ConvNd")
  12. LinearT = TypeVar("LinearT", bound="torch.nn.Linear")
  13. def fuse_conv_bn_eval(
  14. conv: ConvT,
  15. bn: torch.nn.modules.batchnorm._BatchNorm,
  16. transpose: bool = False,
  17. ) -> ConvT:
  18. r"""Fuse a convolutional module and a BatchNorm module into a single, new convolutional module.
  19. Args:
  20. conv (torch.nn.modules.conv._ConvNd): A convolutional module.
  21. bn (torch.nn.modules.batchnorm._BatchNorm): A BatchNorm module.
  22. transpose (bool, optional): If True, transpose the convolutional weight. Defaults to False.
  23. Returns:
  24. torch.nn.modules.conv._ConvNd: The fused convolutional module.
  25. .. note::
  26. Both ``conv`` and ``bn`` must be in eval mode, and ``bn`` must have its running buffers computed.
  27. """
  28. if conv.training or bn.training:
  29. raise AssertionError("Fusion only for eval!")
  30. fused_conv = copy.deepcopy(conv)
  31. if bn.running_mean is None or bn.running_var is None:
  32. raise AssertionError("bn.running_mean and bn.running_var must not be None")
  33. fused_conv.weight, fused_conv.bias = fuse_conv_bn_weights(
  34. fused_conv.weight,
  35. fused_conv.bias,
  36. bn.running_mean,
  37. bn.running_var,
  38. bn.eps,
  39. bn.weight,
  40. bn.bias,
  41. transpose,
  42. )
  43. return fused_conv
  44. def fuse_conv_bn_weights(
  45. conv_w: torch.Tensor,
  46. conv_b: torch.Tensor | None,
  47. bn_rm: torch.Tensor,
  48. bn_rv: torch.Tensor,
  49. bn_eps: float,
  50. bn_w: torch.Tensor | None,
  51. bn_b: torch.Tensor | None,
  52. transpose: bool = False,
  53. ) -> tuple[torch.nn.Parameter, torch.nn.Parameter]:
  54. r"""Fuse convolutional module parameters and BatchNorm module parameters into new convolutional module parameters.
  55. Args:
  56. conv_w (torch.Tensor): Convolutional weight.
  57. conv_b (Optional[torch.Tensor]): Convolutional bias.
  58. bn_rm (torch.Tensor): BatchNorm running mean.
  59. bn_rv (torch.Tensor): BatchNorm running variance.
  60. bn_eps (float): BatchNorm epsilon.
  61. bn_w (Optional[torch.Tensor]): BatchNorm weight.
  62. bn_b (Optional[torch.Tensor]): BatchNorm bias.
  63. transpose (bool, optional): If True, transpose the conv weight. Defaults to False.
  64. Returns:
  65. Tuple[torch.nn.Parameter, torch.nn.Parameter]: Fused convolutional weight and bias.
  66. """
  67. conv_weight_dtype = conv_w.dtype
  68. conv_bias_dtype = conv_b.dtype if conv_b is not None else conv_weight_dtype
  69. if conv_b is None:
  70. conv_b = torch.zeros_like(bn_rm)
  71. if bn_w is None:
  72. bn_w = torch.ones_like(bn_rm)
  73. if bn_b is None:
  74. bn_b = torch.zeros_like(bn_rm)
  75. bn_var_rsqrt = torch.rsqrt(bn_rv + bn_eps)
  76. if transpose:
  77. shape = [1, -1] + [1] * (len(conv_w.shape) - 2)
  78. else:
  79. shape = [-1, 1] + [1] * (len(conv_w.shape) - 2)
  80. fused_conv_w = (conv_w * (bn_w * bn_var_rsqrt).reshape(shape)).to(
  81. dtype=conv_weight_dtype
  82. )
  83. fused_conv_b = ((conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b).to(
  84. dtype=conv_bias_dtype
  85. )
  86. return (
  87. torch.nn.Parameter(fused_conv_w, conv_w.requires_grad),
  88. torch.nn.Parameter(fused_conv_b, conv_b.requires_grad),
  89. )
  90. def fuse_linear_bn_eval(
  91. linear: LinearT,
  92. bn: torch.nn.modules.batchnorm._BatchNorm,
  93. ) -> LinearT:
  94. r"""Fuse a linear module and a BatchNorm module into a single, new linear module.
  95. Args:
  96. linear (torch.nn.Linear): A Linear module.
  97. bn (torch.nn.modules.batchnorm._BatchNorm): A BatchNorm module.
  98. Returns:
  99. torch.nn.Linear: The fused linear module.
  100. .. note::
  101. Both ``linear`` and ``bn`` must be in eval mode, and ``bn`` must have its running buffers computed.
  102. """
  103. if linear.training or bn.training:
  104. raise AssertionError("Fusion only for eval!")
  105. fused_linear = copy.deepcopy(linear)
  106. """
  107. Linear-BN needs to be fused while preserving the shapes of linear weight/bias.
  108. To preserve the shapes of linear weight/bias, the channel dim of bn needs to be broadcastable with the last dim of linear,
  109. because bn operates over the channel dim, (N, C_in, H, W) while linear operates over the last dim, (*, H_in).
  110. To be broadcastable, the number of features in bn and
  111. the number of output features from linear must satisfy the following condition:
  112. 1. they are equal, or
  113. 2. the number of features in bn is 1
  114. Otherwise, skip the folding path
  115. """
  116. if linear.out_features != bn.num_features and bn.num_features != 1:
  117. raise AssertionError(
  118. f"To fuse, linear.out_features == bn.num_features or bn.num_features == 1, "
  119. f"got linear.out_features={linear.out_features} and bn.num_features={bn.num_features}"
  120. )
  121. if bn.running_mean is None or bn.running_var is None:
  122. raise AssertionError("bn.running_mean and bn.running_var must not be None")
  123. fused_linear.weight, fused_linear.bias = fuse_linear_bn_weights(
  124. fused_linear.weight,
  125. fused_linear.bias,
  126. bn.running_mean,
  127. bn.running_var,
  128. bn.eps,
  129. bn.weight,
  130. bn.bias,
  131. )
  132. return fused_linear
  133. def fuse_linear_bn_weights(
  134. linear_w: torch.Tensor,
  135. linear_b: torch.Tensor | None,
  136. bn_rm: torch.Tensor,
  137. bn_rv: torch.Tensor,
  138. bn_eps: float,
  139. bn_w: torch.Tensor,
  140. bn_b: torch.Tensor,
  141. ) -> tuple[torch.nn.Parameter, torch.nn.Parameter]:
  142. r"""Fuse linear module parameters and BatchNorm module parameters into new linear module parameters.
  143. Args:
  144. linear_w (torch.Tensor): Linear weight.
  145. linear_b (Optional[torch.Tensor]): Linear bias.
  146. bn_rm (torch.Tensor): BatchNorm running mean.
  147. bn_rv (torch.Tensor): BatchNorm running variance.
  148. bn_eps (float): BatchNorm epsilon.
  149. bn_w (torch.Tensor): BatchNorm weight.
  150. bn_b (torch.Tensor): BatchNorm bias.
  151. Returns:
  152. Tuple[torch.nn.Parameter, torch.nn.Parameter]: Fused linear weight and bias.
  153. """
  154. linear_weight_dtype = linear_w.dtype
  155. linear_bias_dtype = linear_b.dtype if linear_b is not None else linear_weight_dtype
  156. if linear_b is None:
  157. linear_b = torch.zeros_like(bn_rm)
  158. bn_scale = bn_w * torch.rsqrt(bn_rv + bn_eps)
  159. fused_w = linear_w * bn_scale.unsqueeze(-1).to(dtype=linear_weight_dtype)
  160. fused_b = ((linear_b - bn_rm) * bn_scale + bn_b).to(dtype=linear_bias_dtype)
  161. return torch.nn.Parameter(fused_w, linear_w.requires_grad), torch.nn.Parameter(
  162. fused_b, linear_b.requires_grad
  163. )