misc.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321
  1. import warnings
  2. from collections.abc import Sequence
  3. from typing import Callable, Optional, Union
  4. import torch
  5. from torch import Tensor
  6. from ..utils import _log_api_usage_once, _make_ntuple
  7. interpolate = torch.nn.functional.interpolate
  class FrozenBatchNorm2d(torch.nn.Module):
      """
      BatchNorm2d variant whose statistics and affine terms are fixed.

      ``weight``, ``bias``, ``running_mean`` and ``running_var`` are registered as
      buffers (not parameters), and the forward pass only reads them.

      Args:
          num_features (int): Number of features ``C`` from an expected input of size ``(N, C, H, W)``
          eps (float): a value added to the denominator for numerical stability. Default: 1e-5
      """

      def __init__(
          self,
          num_features: int,
          eps: float = 1e-5,
      ):
          super().__init__()
          _log_api_usage_once(self)
          self.eps = eps
          # Registration order is kept stable because it fixes the key order of the state dict.
          for buffer_name, make_initial in (
              ("weight", torch.ones),
              ("bias", torch.zeros),
              ("running_mean", torch.zeros),
              ("running_var", torch.ones),
          ):
              self.register_buffer(buffer_name, make_initial(num_features))

      def _load_from_state_dict(
          self,
          state_dict: dict,
          prefix: str,
          local_metadata: dict,
          strict: bool,
          missing_keys: list[str],
          unexpected_keys: list[str],
          error_msgs: list[str],
      ):
          """Load this module's buffers, discarding any ``num_batches_tracked`` entry first.

          This module registers no ``num_batches_tracked`` buffer, so the key is removed
          from ``state_dict`` before delegating to the parent implementation.
          """
          state_dict.pop(prefix + "num_batches_tracked", None)
          super()._load_from_state_dict(
              state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
          )

      def forward(self, x: Tensor) -> Tensor:
          """Apply ``x * scale + shift`` using the frozen per-channel buffers."""
          # All reshapes happen first (before any arithmetic) to stay fuser-friendly;
          # each buffer becomes (1, C, 1, 1).
          gamma = self.weight.reshape(1, -1, 1, 1)
          beta = self.bias.reshape(1, -1, 1, 1)
          variance = self.running_var.reshape(1, -1, 1, 1)
          mean = self.running_mean.reshape(1, -1, 1, 1)
          # Fold normalization and affine transform into one multiply-add.
          scale = gamma * (variance + self.eps).rsqrt()
          shift = beta - mean * scale
          return x * scale + shift

      def __repr__(self) -> str:
          """Return ``ClassName(num_features, eps=...)``."""
          class_name = type(self).__name__
          num_features = self.weight.shape[0]
          return f"{class_name}({num_features}, eps={self.eps})"
  class ConvNormActivation(torch.nn.Sequential):
      """
      Base block chaining a convolution with optional normalization and activation layers.

      Instantiating this class itself (rather than a subclass) emits a warning that
      recommends ``Conv2dNormActivation`` / ``Conv3dNormActivation``.

      Args:
          in_channels (int): Number of input channels passed to ``conv_layer``
          out_channels (int): Number of output channels passed to ``conv_layer`` and ``norm_layer``
          kernel_size (int or tuple, optional): Size of the convolving kernel. Default: 3
          stride (int or tuple, optional): Stride of the convolution. Default: 1
          padding (int, tuple or str, optional): Padding passed to ``conv_layer``. Default: None, in which case
              it is computed per dimension as ``(kernel_size - 1) // 2 * dilation``
          groups (int, optional): ``groups`` argument of ``conv_layer``. Default: 1
          norm_layer (Callable[..., torch.nn.Module], optional): Called with ``out_channels`` and appended after
              the convolution. If ``None`` no norm layer is added. Default: ``torch.nn.BatchNorm2d``
          activation_layer (Callable[..., torch.nn.Module], optional): Appended last. If ``None`` no activation
              is added. Default: ``torch.nn.ReLU``
          dilation (int or tuple): ``dilation`` argument of ``conv_layer``. Default: 1
          inplace (bool, optional): Forwarded to ``activation_layer`` as ``inplace=...``; if ``None`` the
              activation is built without arguments. Default: ``True``
          bias (bool, optional): ``bias`` argument of ``conv_layer``. Default: ``None``, meaning a bias is used
              only when ``norm_layer is None``
          conv_layer (Callable[..., torch.nn.Module]): Factory for the convolution. Default: ``torch.nn.Conv2d``
      """

      def __init__(
          self,
          in_channels: int,
          out_channels: int,
          kernel_size: Union[int, tuple[int, ...]] = 3,
          stride: Union[int, tuple[int, ...]] = 1,
          padding: Optional[Union[int, tuple[int, ...], str]] = None,
          groups: int = 1,
          norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
          activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
          dilation: Union[int, tuple[int, ...]] = 1,
          inplace: Optional[bool] = True,
          bias: Optional[bool] = None,
          conv_layer: Callable[..., torch.nn.Module] = torch.nn.Conv2d,
      ) -> None:
          if padding is None:
              scalar_geometry = isinstance(kernel_size, int) and isinstance(dilation, int)
              if scalar_geometry:
                  padding = (kernel_size - 1) // 2 * dilation
              else:
                  # At least one of kernel_size / dilation is a sequence; its length
                  # gives the number of spatial dimensions.
                  if isinstance(kernel_size, Sequence):
                      spatial_dims = len(kernel_size)
                  else:
                      spatial_dims = len(dilation)
                  kernel_size = _make_ntuple(kernel_size, spatial_dims)
                  dilation = _make_ntuple(dilation, spatial_dims)
                  padding = tuple((kernel_size[axis] - 1) // 2 * dilation[axis] for axis in range(spatial_dims))

          # Unless the caller decides explicitly, the conv bias is only kept when
          # no normalization layer follows.
          use_bias = (norm_layer is None) if bias is None else bias

          conv = conv_layer(
              in_channels,
              out_channels,
              kernel_size,
              stride,
              padding,
              dilation=dilation,
              groups=groups,
              bias=use_bias,
          )
          stages = [conv]

          if norm_layer is not None:
              stages.append(norm_layer(out_channels))

          if activation_layer is not None:
              activation_kwargs = {"inplace": inplace} if inplace is not None else {}
              stages.append(activation_layer(**activation_kwargs))

          super().__init__(*stages)
          _log_api_usage_once(self)
          self.out_channels = out_channels

          # Only the base class itself warns; subclasses are the supported entry points.
          if type(self) is ConvNormActivation:
              warnings.warn(
                  "Don't use ConvNormActivation directly, please use Conv2dNormActivation and Conv3dNormActivation instead."
              )
  class Conv2dNormActivation(ConvNormActivation):
      """
      Configurable Convolution2d-Normalization-Activation block.

      All arguments are forwarded to ``ConvNormActivation`` with ``torch.nn.Conv2d``
      as the convolution layer.

      Args:
          in_channels (int): Number of channels in the input image
          out_channels (int): Number of channels produced by the Convolution-Normalization-Activation block
          kernel_size (int or tuple, optional): Size of the convolving kernel. Default: 3
          stride (int or tuple, optional): Stride of the convolution. Default: 1
          padding (int, tuple or str, optional): Padding added to all four sides of the input. Default: None,
              in which case it will be calculated as ``padding = (kernel_size - 1) // 2 * dilation``
          groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
          norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the
              convolution layer. If ``None`` this layer won't be used. Default: ``torch.nn.BatchNorm2d``
          activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked
              on top of the normalization layer (if not None), otherwise on top of the conv layer.
              If ``None`` this layer won't be used. Default: ``torch.nn.ReLU``
          dilation (int or tuple): Spacing between kernel elements. Default: 1
          inplace (bool, optional): Parameter for the activation layer, which can optionally do the operation
              in-place. Default ``True``
          bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included
              if ``norm_layer is None``.
      """

      def __init__(
          self,
          in_channels: int,
          out_channels: int,
          kernel_size: Union[int, tuple[int, int]] = 3,
          stride: Union[int, tuple[int, int]] = 1,
          padding: Optional[Union[int, tuple[int, int], str]] = None,
          groups: int = 1,
          norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d,
          activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
          dilation: Union[int, tuple[int, int]] = 1,
          inplace: Optional[bool] = True,
          bias: Optional[bool] = None,
      ) -> None:
          # Forward everything by keyword; only the convolution type is fixed here.
          super().__init__(
              in_channels=in_channels,
              out_channels=out_channels,
              kernel_size=kernel_size,
              stride=stride,
              padding=padding,
              groups=groups,
              norm_layer=norm_layer,
              activation_layer=activation_layer,
              dilation=dilation,
              inplace=inplace,
              bias=bias,
              conv_layer=torch.nn.Conv2d,
          )
  class Conv3dNormActivation(ConvNormActivation):
      """
      Configurable Convolution3d-Normalization-Activation block.

      All arguments are forwarded to ``ConvNormActivation`` with ``torch.nn.Conv3d``
      as the convolution layer.

      Args:
          in_channels (int): Number of channels in the input video.
          out_channels (int): Number of channels produced by the Convolution-Normalization-Activation block
          kernel_size (int or tuple, optional): Size of the convolving kernel. Default: 3
          stride (int or tuple, optional): Stride of the convolution. Default: 1
          padding (int, tuple or str, optional): Padding added to all sides of the input. Default: None,
              in which case it will be calculated as ``padding = (kernel_size - 1) // 2 * dilation``
          groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
          norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the
              convolution layer. If ``None`` this layer won't be used. Default: ``torch.nn.BatchNorm3d``
          activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked
              on top of the normalization layer (if not None), otherwise on top of the conv layer.
              If ``None`` this layer won't be used. Default: ``torch.nn.ReLU``
          dilation (int or tuple): Spacing between kernel elements. Default: 1
          inplace (bool, optional): Parameter for the activation layer, which can optionally do the operation
              in-place. Default ``True``
          bias (bool, optional): Whether to use bias in the convolution layer. By default, biases are included
              if ``norm_layer is None``.
      """

      def __init__(
          self,
          in_channels: int,
          out_channels: int,
          kernel_size: Union[int, tuple[int, int, int]] = 3,
          stride: Union[int, tuple[int, int, int]] = 1,
          padding: Optional[Union[int, tuple[int, int, int], str]] = None,
          groups: int = 1,
          norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm3d,
          activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
          dilation: Union[int, tuple[int, int, int]] = 1,
          inplace: Optional[bool] = True,
          bias: Optional[bool] = None,
      ) -> None:
          # Forward everything by keyword; only the convolution type is fixed here.
          super().__init__(
              in_channels=in_channels,
              out_channels=out_channels,
              kernel_size=kernel_size,
              stride=stride,
              padding=padding,
              groups=groups,
              norm_layer=norm_layer,
              activation_layer=activation_layer,
              dilation=dilation,
              inplace=inplace,
              bias=bias,
              conv_layer=torch.nn.Conv3d,
          )
  class SqueezeExcitation(torch.nn.Module):
      """
      Squeeze-and-Excitation block from https://arxiv.org/abs/1709.01507 (see Fig. 1).

      The input is pooled to ``1x1``, passed through two ``1x1`` convolutions with an
      activation after each, and the result multiplies the input.
      Parameters ``activation``, and ``scale_activation`` correspond to ``delta`` and ``sigma`` in eq. 3.

      Args:
          input_channels (int): Number of channels in the input image
          squeeze_channels (int): Number of squeeze channels
          activation (Callable[..., torch.nn.Module], optional): ``delta`` activation. Default: ``torch.nn.ReLU``
          scale_activation (Callable[..., torch.nn.Module]): ``sigma`` activation. Default: ``torch.nn.Sigmoid``
      """

      def __init__(
          self,
          input_channels: int,
          squeeze_channels: int,
          activation: Callable[..., torch.nn.Module] = torch.nn.ReLU,
          scale_activation: Callable[..., torch.nn.Module] = torch.nn.Sigmoid,
      ) -> None:
          super().__init__()
          _log_api_usage_once(self)
          # Global average pool down to a single spatial location per channel.
          self.avgpool = torch.nn.AdaptiveAvgPool2d(1)
          # 1x1 convolutions: input_channels -> squeeze_channels -> input_channels.
          self.fc1 = torch.nn.Conv2d(input_channels, squeeze_channels, 1)
          self.fc2 = torch.nn.Conv2d(squeeze_channels, input_channels, 1)
          # Both activation factories are called without arguments.
          self.activation = activation()
          self.scale_activation = scale_activation()

      def _scale(self, input: Tensor) -> Tensor:
          """Compute the multiplicative scale that ``forward`` applies to ``input``."""
          pooled = self.avgpool(input)
          squeezed = self.activation(self.fc1(pooled))
          excited = self.fc2(squeezed)
          return self.scale_activation(excited)

      def forward(self, input: Tensor) -> Tensor:
          """Return ``input`` multiplied by its computed scale."""
          return self._scale(input) * input
  class MLP(torch.nn.Sequential):
      """This block implements the multi-layer perceptron (MLP) module.

      Every entry of ``hidden_channels`` except the last produces
      ``Linear -> [norm] -> [activation] -> Dropout``; the last entry produces
      ``Linear -> Dropout`` only, so it is the output dimension of the block.

      Args:
          in_channels (int): Number of channels of the input
          hidden_channels (List[int]): List of the hidden channel dimensions. Must contain at least one
              entry (an empty list raises ``IndexError``); the last entry is the output dimension.
          norm_layer (Callable[..., torch.nn.Module], optional): Norm layer that will be stacked on top of the
              linear layer. If ``None`` this layer won't be used. Default: ``None``
          activation_layer (Callable[..., torch.nn.Module], optional): Activation function which will be stacked
              on top of the normalization layer (if not None), otherwise on top of the linear layer.
              If ``None`` this layer won't be used. Default: ``torch.nn.ReLU``
          inplace (bool, optional): Parameter for the activation layer, which can optionally do the operation
              in-place. Default is ``None``, which uses the respective default values of the
              ``activation_layer`` and Dropout layer.
          bias (bool): Whether to use bias in the linear layer. Default ``True``
          dropout (float): The probability for the dropout layer. Default: 0.0
      """

      def __init__(
          self,
          in_channels: int,
          hidden_channels: list[int],
          norm_layer: Optional[Callable[..., torch.nn.Module]] = None,
          activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
          inplace: Optional[bool] = None,
          bias: bool = True,
          dropout: float = 0.0,
      ):
          # The addition of `norm_layer` is inspired from the implementation of TorchMultimodal:
          # https://github.com/facebookresearch/multimodal/blob/5dec8a/torchmultimodal/modules/layers/mlp.py
          # The same keyword set is handed to both the activation and the Dropout layers.
          params = {} if inplace is None else {"inplace": inplace}

          layers = []
          in_dim = in_channels
          for hidden_dim in hidden_channels[:-1]:
              layers.append(torch.nn.Linear(in_dim, hidden_dim, bias=bias))
              if norm_layer is not None:
                  layers.append(norm_layer(hidden_dim))
              # `activation_layer` is Optional and documented as skippable; calling it
              # unconditionally raised TypeError when it was None.
              if activation_layer is not None:
                  layers.append(activation_layer(**params))
              layers.append(torch.nn.Dropout(dropout, **params))
              in_dim = hidden_dim

          # Final projection: no normalization or activation, only dropout.
          layers.append(torch.nn.Linear(in_dim, hidden_channels[-1], bias=bias))
          layers.append(torch.nn.Dropout(dropout, **params))

          super().__init__(*layers)
          _log_api_usage_once(self)
  class Permute(torch.nn.Module):
      """Module wrapper around ``torch.permute`` with a fixed dimension ordering.

      Args:
          dims (List[int]): The desired ordering of dimensions
      """

      def __init__(self, dims: list[int]):
          super().__init__()
          # Stored as given; it is passed unchanged to ``torch.permute`` on every call.
          self.dims = dims

      def forward(self, x: Tensor) -> Tensor:
          """Return ``x`` with its dimensions reordered according to ``self.dims``."""
          permuted = torch.permute(x, self.dims)
          return permuted