common.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from __future__ import annotations
  18. import torch
  19. import torch.nn.functional as F
  20. from torch import nn
  21. from kornia.core import Module, Tensor, pad
  22. class ConvNormAct(nn.Sequential):
  23. def __init__(
  24. self,
  25. in_channels: int,
  26. out_channels: int,
  27. kernel_size: int,
  28. stride: int = 1,
  29. act: str = "relu",
  30. groups: int = 1,
  31. conv_naming: str = "conv",
  32. norm_naming: str = "norm",
  33. act_naming: str = "act",
  34. ) -> None:
  35. super().__init__()
  36. if kernel_size % 2 == 0:
  37. # even kernel_size -> asymmetric padding
  38. # PPHGNetV2 (for RT-DETR) uses kernel 2
  39. # follow TensorFlow/PaddlePaddle: bottom/right side is padded 1 more than top/left
  40. # NOTE: this does not account for stride=2
  41. p1 = (kernel_size - 1) // 2
  42. p2 = kernel_size - 1 - p1
  43. self.pad = nn.ZeroPad2d((p1, p2, p1, p2))
  44. padding = 0
  45. else:
  46. padding = (kernel_size - 1) // 2
  47. conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, 1, groups, False)
  48. norm = nn.BatchNorm2d(out_channels)
  49. activation = {"relu": nn.ReLU, "silu": nn.SiLU, "none": nn.Identity}[act](inplace=True)
  50. self.__setattr__(conv_naming, conv)
  51. self.__setattr__(norm_naming, norm)
  52. self.__setattr__(act_naming, activation)
  53. # Lightly adapted from
  54. # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py
  55. class MLP(Module):
  56. def __init__(
  57. self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, sigmoid_output: bool = False
  58. ) -> None:
  59. super().__init__()
  60. self.num_layers = num_layers
  61. h = [hidden_dim] * (num_layers - 1)
  62. self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim, *h], [*h, output_dim]))
  63. self.sigmoid_output = sigmoid_output
  64. def forward(self, x: Tensor) -> Tensor:
  65. for i, layer in enumerate(self.layers):
  66. x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
  67. if self.sigmoid_output:
  68. x = F.sigmoid(x)
  69. return x
  70. # Adapted from timm
  71. # https://github.com/huggingface/pytorch-image-models/blob/v0.9.2/timm/layers/drop.py#L137
  72. class DropPath(Module):
  73. """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
  74. def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True) -> None:
  75. super().__init__()
  76. self.drop_prob = drop_prob
  77. self.scale_by_keep = scale_by_keep
  78. def forward(self, x: Tensor) -> Tensor:
  79. if self.drop_prob == 0.0 or not self.training:
  80. return x
  81. keep_prob = 1 - self.drop_prob
  82. shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
  83. random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
  84. if keep_prob > 0.0 and self.scale_by_keep:
  85. random_tensor.div_(keep_prob)
  86. return x * random_tensor
  87. # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py
  88. # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
  89. class LayerNorm2d(Module):
  90. def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
  91. super().__init__()
  92. self.weight = nn.Parameter(torch.ones(num_channels))
  93. self.bias = nn.Parameter(torch.zeros(num_channels))
  94. self.eps = eps
  95. def forward(self, x: Tensor) -> Tensor:
  96. u = x.mean(1, keepdim=True)
  97. s = (x - u).pow(2).mean(1, keepdim=True)
  98. x = (x - u) / (s + self.eps).sqrt()
  99. x = self.weight[:, None, None] * x + self.bias[:, None, None]
  100. return x
  101. def window_partition(x: Tensor, window_size: int) -> tuple[Tensor, tuple[int, int]]:
  102. """Partition into non-overlapping windows with padding if needed.
  103. Args:
  104. x: input tokens with [B, H, W, C].
  105. window_size: window size.
  106. Returns:
  107. windows: windows after partition with [B * num_windows, window_size, window_size, C].
  108. (Hp, Wp): padded height and width before partition
  109. """
  110. B, H, W, C = x.shape
  111. pad_h = (window_size - H % window_size) % window_size
  112. pad_w = (window_size - W % window_size) % window_size
  113. if pad_h > 0 or pad_w > 0:
  114. x = pad(x, (0, 0, 0, pad_w, 0, pad_h))
  115. Hp, Wp = H + pad_h, W + pad_w
  116. x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
  117. windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
  118. return windows, (Hp, Wp)
  119. def window_unpartition(windows: Tensor, window_size: int, pad_hw: tuple[int, int], hw: tuple[int, int]) -> Tensor:
  120. """Window unpartition into original sequences and removing padding.
  121. Args:
  122. windows: input tokens with [B * num_windows, window_size, window_size, C].
  123. window_size: window size.
  124. pad_hw: padded height and width (Hp, Wp).
  125. hw: original height and width (H, W) before padding.
  126. Returns:
  127. x: unpartitioned sequences with [B, H, W, C].
  128. """
  129. Hp, Wp = pad_hw
  130. H, W = hw
  131. B = windows.shape[0] // (Hp * Wp // window_size // window_size)
  132. x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
  133. x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
  134. if Hp > H or Wp > W:
  135. x = x[:, :H, :W, :].contiguous()
  136. return x