drop_block.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. import torch
  2. import torch.fx
  3. import torch.nn.functional as F
  4. from torch import nn, Tensor
  5. from ..utils import _log_api_usage_once
  6. def drop_block2d(
  7. input: Tensor, p: float, block_size: int, inplace: bool = False, eps: float = 1e-06, training: bool = True
  8. ) -> Tensor:
  9. """
  10. Implements DropBlock2d from `"DropBlock: A regularization method for convolutional networks"
  11. <https://arxiv.org/abs/1810.12890>`.
  12. Args:
  13. input (Tensor[N, C, H, W]): The input tensor or 4-dimensions with the first one
  14. being its batch i.e. a batch with ``N`` rows.
  15. p (float): Probability of an element to be dropped.
  16. block_size (int): Size of the block to drop.
  17. inplace (bool): If set to ``True``, will do this operation in-place. Default: ``False``.
  18. eps (float): A value added to the denominator for numerical stability. Default: 1e-6.
  19. training (bool): apply dropblock if is ``True``. Default: ``True``.
  20. Returns:
  21. Tensor[N, C, H, W]: The randomly zeroed tensor after dropblock.
  22. """
  23. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  24. _log_api_usage_once(drop_block2d)
  25. if p < 0.0 or p > 1.0:
  26. raise ValueError(f"drop probability has to be between 0 and 1, but got {p}.")
  27. if input.ndim != 4:
  28. raise ValueError(f"input should be 4 dimensional. Got {input.ndim} dimensions.")
  29. if not training or p == 0.0:
  30. return input
  31. N, C, H, W = input.size()
  32. block_size = min(block_size, W, H)
  33. if block_size % 2 == 0:
  34. raise ValueError(f"block size should be odd. Got {block_size} which is even.")
  35. # compute the gamma of Bernoulli distribution
  36. gamma = (p * H * W) / ((block_size**2) * ((H - block_size + 1) * (W - block_size + 1)))
  37. noise = torch.empty((N, C, H - block_size + 1, W - block_size + 1), dtype=input.dtype, device=input.device)
  38. noise.bernoulli_(gamma)
  39. noise = F.pad(noise, [block_size // 2] * 4, value=0)
  40. noise = F.max_pool2d(noise, stride=(1, 1), kernel_size=(block_size, block_size), padding=block_size // 2)
  41. noise = 1 - noise
  42. normalize_scale = noise.numel() / (eps + noise.sum())
  43. if inplace:
  44. input.mul_(noise).mul_(normalize_scale)
  45. else:
  46. input = input * noise * normalize_scale
  47. return input
  48. def drop_block3d(
  49. input: Tensor, p: float, block_size: int, inplace: bool = False, eps: float = 1e-06, training: bool = True
  50. ) -> Tensor:
  51. """
  52. Implements DropBlock3d from `"DropBlock: A regularization method for convolutional networks"
  53. <https://arxiv.org/abs/1810.12890>`.
  54. Args:
  55. input (Tensor[N, C, D, H, W]): The input tensor or 5-dimensions with the first one
  56. being its batch i.e. a batch with ``N`` rows.
  57. p (float): Probability of an element to be dropped.
  58. block_size (int): Size of the block to drop.
  59. inplace (bool): If set to ``True``, will do this operation in-place. Default: ``False``.
  60. eps (float): A value added to the denominator for numerical stability. Default: 1e-6.
  61. training (bool): apply dropblock if is ``True``. Default: ``True``.
  62. Returns:
  63. Tensor[N, C, D, H, W]: The randomly zeroed tensor after dropblock.
  64. """
  65. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  66. _log_api_usage_once(drop_block3d)
  67. if p < 0.0 or p > 1.0:
  68. raise ValueError(f"drop probability has to be between 0 and 1, but got {p}.")
  69. if input.ndim != 5:
  70. raise ValueError(f"input should be 5 dimensional. Got {input.ndim} dimensions.")
  71. if not training or p == 0.0:
  72. return input
  73. N, C, D, H, W = input.size()
  74. block_size = min(block_size, D, H, W)
  75. if block_size % 2 == 0:
  76. raise ValueError(f"block size should be odd. Got {block_size} which is even.")
  77. # compute the gamma of Bernoulli distribution
  78. gamma = (p * D * H * W) / ((block_size**3) * ((D - block_size + 1) * (H - block_size + 1) * (W - block_size + 1)))
  79. noise = torch.empty(
  80. (N, C, D - block_size + 1, H - block_size + 1, W - block_size + 1), dtype=input.dtype, device=input.device
  81. )
  82. noise.bernoulli_(gamma)
  83. noise = F.pad(noise, [block_size // 2] * 6, value=0)
  84. noise = F.max_pool3d(
  85. noise, stride=(1, 1, 1), kernel_size=(block_size, block_size, block_size), padding=block_size // 2
  86. )
  87. noise = 1 - noise
  88. normalize_scale = noise.numel() / (eps + noise.sum())
  89. if inplace:
  90. input.mul_(noise).mul_(normalize_scale)
  91. else:
  92. input = input * noise * normalize_scale
  93. return input
  94. torch.fx.wrap("drop_block2d")
  95. class DropBlock2d(nn.Module):
  96. """
  97. See :func:`drop_block2d`.
  98. """
  99. def __init__(self, p: float, block_size: int, inplace: bool = False, eps: float = 1e-06) -> None:
  100. super().__init__()
  101. self.p = p
  102. self.block_size = block_size
  103. self.inplace = inplace
  104. self.eps = eps
  105. def forward(self, input: Tensor) -> Tensor:
  106. """
  107. Args:
  108. input (Tensor): Input feature map on which some areas will be randomly
  109. dropped.
  110. Returns:
  111. Tensor: The tensor after DropBlock layer.
  112. """
  113. return drop_block2d(input, self.p, self.block_size, self.inplace, self.eps, self.training)
  114. def __repr__(self) -> str:
  115. s = f"{self.__class__.__name__}(p={self.p}, block_size={self.block_size}, inplace={self.inplace})"
  116. return s
  117. torch.fx.wrap("drop_block3d")
  118. class DropBlock3d(DropBlock2d):
  119. """
  120. See :func:`drop_block3d`.
  121. """
  122. def __init__(self, p: float, block_size: int, inplace: bool = False, eps: float = 1e-06) -> None:
  123. super().__init__(p, block_size, inplace, eps)
  124. def forward(self, input: Tensor) -> Tensor:
  125. """
  126. Args:
  127. input (Tensor): Input feature map on which some areas will be randomly
  128. dropped.
  129. Returns:
  130. Tensor: The tensor after DropBlock layer.
  131. """
  132. return drop_block3d(input, self.p, self.block_size, self.inplace, self.eps, self.training)