drop.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231
  1. """ DropBlock, DropPath
  2. PyTorch implementations of DropBlock and DropPath (Stochastic Depth) regularization layers.
  3. Papers:
  4. DropBlock: A regularization method for convolutional networks (https://arxiv.org/abs/1810.12890)
  5. Deep Networks with Stochastic Depth (https://arxiv.org/abs/1603.09382)
  6. Code:
  7. DropBlock impl inspired by two Tensorflow impl that I liked:
  8. - https://github.com/tensorflow/tpu/blob/master/models/official/resnet/resnet_model.py#L74
  9. - https://github.com/clovaai/assembled-cnn/blob/master/nets/blocks.py
  10. Hacked together by / Copyright 2020 Ross Wightman
  11. """
  12. from typing import List, Union
  13. import torch
  14. import torch.nn as nn
  15. import torch.nn.functional as F
  16. def drop_block_2d(
  17. x: torch.Tensor,
  18. drop_prob: float = 0.1,
  19. block_size: int = 7,
  20. gamma_scale: float = 1.0,
  21. with_noise: bool = False,
  22. inplace: bool = False,
  23. couple_channels: bool = True,
  24. scale_by_keep: bool = True,
  25. ):
  26. """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
  27. DropBlock with an experimental gaussian noise option.
  28. Args:
  29. x: Input tensor of shape (B, C, H, W).
  30. drop_prob: Probability of dropping a block.
  31. block_size: Size of the block to drop.
  32. gamma_scale: Scale factor for the drop probability.
  33. with_noise: If True, add gaussian noise to dropped regions instead of zeros.
  34. inplace: If True, perform operation in-place.
  35. couple_channels: If True, all channels share the same drop mask (per the original paper).
  36. If False, each channel gets an independent mask.
  37. scale_by_keep: If True, scale kept activations to maintain expected values.
  38. Returns:
  39. Tensor with dropped blocks, same shape as input.
  40. """
  41. B, C, H, W = x.shape
  42. kh, kw = min(block_size, H), min(block_size, W)
  43. # Compute gamma (seed drop rate) - probability of dropping each spatial location
  44. gamma = float(gamma_scale * drop_prob * H * W) / float(kh * kw) / float((H - kh + 1) * (W - kw + 1))
  45. # Generate drop mask: 1 at block centers to drop, 0 elsewhere
  46. # couple_channels=True means all channels share same spatial mask (matches paper)
  47. noise_shape = (B, 1 if couple_channels else C, H, W)
  48. with torch.no_grad():
  49. block_mask = torch.empty(noise_shape, dtype=x.dtype, device=x.device).bernoulli_(gamma)
  50. # Expand block centers to full blocks using max pooling
  51. block_mask = F.max_pool2d(
  52. block_mask,
  53. kernel_size=(kh, kw),
  54. stride=1,
  55. padding=(kh // 2, kw // 2),
  56. )
  57. # Handle even kernel sizes - max_pool2d output is 1 larger in each even dimension
  58. if kh % 2 == 0 or kw % 2 == 0:
  59. # Fix for even kernels proposed by https://github.com/crutcher
  60. block_mask = block_mask[..., (kh + 1) % 2:, (kw + 1) % 2:]
  61. keep_mask = 1. - block_mask
  62. if with_noise:
  63. with torch.no_grad():
  64. noise = torch.empty_like(keep_mask).normal_()
  65. noise.mul_(block_mask)
  66. if inplace:
  67. x.mul_(keep_mask).add_(noise)
  68. else:
  69. x = x * keep_mask + noise
  70. else:
  71. if scale_by_keep:
  72. with torch.no_grad():
  73. # Normalize to maintain expected values (scale up kept activations)
  74. normalize_scale = keep_mask.numel() / keep_mask.to(dtype=torch.float32).sum().add(1e-7)
  75. keep_mask.mul_(normalize_scale.to(x.dtype))
  76. if inplace:
  77. x.mul_(keep_mask)
  78. else:
  79. x = x * keep_mask
  80. return x
  81. class DropBlock2d(nn.Module):
  82. """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
  83. Args:
  84. drop_prob: Probability of dropping a block.
  85. block_size: Size of the block to drop.
  86. gamma_scale: Scale factor for the drop probability.
  87. with_noise: If True, add gaussian noise to dropped regions instead of zeros.
  88. inplace: If True, perform operation in-place.
  89. couple_channels: If True, all channels share the same drop mask (per the original paper).
  90. If False, each channel gets an independent mask.
  91. scale_by_keep: If True, scale kept activations to maintain expected values.
  92. """
  93. def __init__(
  94. self,
  95. drop_prob: float = 0.1,
  96. block_size: int = 7,
  97. gamma_scale: float = 1.0,
  98. with_noise: bool = False,
  99. inplace: bool = False,
  100. couple_channels: bool = True,
  101. scale_by_keep: bool = True,
  102. **kwargs,
  103. ):
  104. super().__init__()
  105. self.drop_prob = drop_prob
  106. self.gamma_scale = gamma_scale
  107. self.block_size = block_size
  108. self.with_noise = with_noise
  109. self.inplace = inplace
  110. self.couple_channels = couple_channels
  111. self.scale_by_keep = scale_by_keep
  112. # Backwards compatibility: silently consume args removed in v1.0.23, warn on unknown
  113. deprecated_args = {'batchwise', 'fast'}
  114. for k in kwargs:
  115. if k not in deprecated_args:
  116. import warnings
  117. warnings.warn(f"DropBlock2d() got unexpected keyword argument '{k}'")
  118. def forward(self, x):
  119. if not self.training or not self.drop_prob:
  120. return x
  121. return drop_block_2d(
  122. x,
  123. drop_prob=self.drop_prob,
  124. block_size=self.block_size,
  125. gamma_scale=self.gamma_scale,
  126. with_noise=self.with_noise,
  127. inplace=self.inplace,
  128. couple_channels=self.couple_channels,
  129. scale_by_keep=self.scale_by_keep,
  130. )
  131. def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
  132. """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
  133. This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
  134. the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
  135. See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
  136. changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
  137. 'survival rate' as the argument.
  138. """
  139. if drop_prob == 0. or not training:
  140. return x
  141. keep_prob = 1 - drop_prob
  142. shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
  143. random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
  144. if keep_prob > 0.0 and scale_by_keep:
  145. random_tensor.div_(keep_prob)
  146. return x * random_tensor
  147. class DropPath(nn.Module):
  148. """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
  149. """
  150. def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
  151. super().__init__()
  152. self.drop_prob = drop_prob
  153. self.scale_by_keep = scale_by_keep
  154. def forward(self, x):
  155. return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
  156. def extra_repr(self):
  157. return f'drop_prob={round(self.drop_prob,3):0.3f}'
  158. def calculate_drop_path_rates(
  159. drop_path_rate: float,
  160. depths: Union[int, List[int]],
  161. stagewise: bool = False,
  162. ) -> Union[List[float], List[List[float]]]:
  163. """Generate drop path rates for stochastic depth.
  164. This function handles two common patterns for drop path rate scheduling:
  165. 1. Per-block: Linear increase from 0 to drop_path_rate across all blocks
  166. 2. Stage-wise: Linear increase across stages, with same rate within each stage
  167. Args:
  168. drop_path_rate: Maximum drop path rate (at the end).
  169. depths: Either a single int for total depth (per-block mode) or
  170. list of ints for depths per stage (stage-wise mode).
  171. stagewise: If True, use stage-wise pattern. If False, use per-block pattern.
  172. When depths is a list, stagewise defaults to True.
  173. Returns:
  174. For per-block mode: List of drop rates, one per block.
  175. For stage-wise mode: List of lists, drop rates per stage.
  176. """
  177. if isinstance(depths, int):
  178. # Single depth value - per-block pattern
  179. if stagewise:
  180. raise ValueError("stagewise=True requires depths to be a list of stage depths")
  181. dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depths, device='cpu')]
  182. return dpr
  183. else:
  184. # List of depths - can be either pattern
  185. total_depth = sum(depths)
  186. if stagewise:
  187. # Stage-wise pattern: same drop rate within each stage
  188. dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, total_depth, device='cpu').split(depths)]
  189. return dpr
  190. else:
  191. # Per-block pattern across all stages
  192. dpr = [x.item() for x in torch.linspace(0, drop_path_rate, total_depth, device='cpu')]
  193. return dpr