dropout.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323
  1. import torch.nn.functional as F
  2. from torch import Tensor
  3. from .module import Module
  4. __all__ = [
  5. "Dropout",
  6. "Dropout1d",
  7. "Dropout2d",
  8. "Dropout3d",
  9. "AlphaDropout",
  10. "FeatureAlphaDropout",
  11. ]
  12. class _DropoutNd(Module):
  13. __constants__ = ["p", "inplace"]
  14. p: float
  15. inplace: bool
  16. def __init__(self, p: float = 0.5, inplace: bool = False) -> None:
  17. super().__init__()
  18. if p < 0 or p > 1:
  19. raise ValueError(
  20. f"dropout probability has to be between 0 and 1, but got {p}"
  21. )
  22. self.p = p
  23. self.inplace = inplace
  24. def extra_repr(self) -> str:
  25. return f"p={self.p}, inplace={self.inplace}"
  26. class Dropout(_DropoutNd):
  27. r"""During training, randomly zeroes some of the elements of the input tensor with probability :attr:`p`.
  28. The zeroed elements are chosen independently for each forward call and are sampled from a Bernoulli distribution.
  29. Each channel will be zeroed out independently on every forward call.
  30. This has proven to be an effective technique for regularization and
  31. preventing the co-adaptation of neurons as described in the paper
  32. `Improving neural networks by preventing co-adaptation of feature
  33. detectors`_ .
  34. Furthermore, the outputs are scaled by a factor of :math:`\frac{1}{1-p}` during
  35. training. This means that during evaluation the module simply computes an
  36. identity function.
  37. Args:
  38. p: probability of an element to be zeroed. Default: 0.5
  39. inplace: If set to ``True``, will do this operation in-place. Default: ``False``
  40. Shape:
  41. - Input: :math:`(*)`. Input can be of any shape
  42. - Output: :math:`(*)`. Output is of the same shape as input
  43. Examples::
  44. >>> m = nn.Dropout(p=0.2)
  45. >>> input = torch.randn(20, 16)
  46. >>> output = m(input)
  47. .. _Improving neural networks by preventing co-adaptation of feature
  48. detectors: https://arxiv.org/abs/1207.0580
  49. """
  50. def forward(self, input: Tensor) -> Tensor:
  51. """
  52. Runs the forward pass.
  53. """
  54. return F.dropout(input, self.p, self.training, self.inplace)
  55. class Dropout1d(_DropoutNd):
  56. r"""Randomly zero out entire channels.
  57. A channel is a 1D feature map,
  58. e.g., the :math:`j`-th channel of the :math:`i`-th sample in the
  59. batched input is a 1D tensor :math:`\text{input}[i, j]`.
  60. Each channel will be zeroed out independently on every forward call with
  61. probability :attr:`p` using samples from a Bernoulli distribution.
  62. Usually the input comes from :class:`nn.Conv1d` modules.
  63. As described in the paper
  64. `Efficient Object Localization Using Convolutional Networks`_ ,
  65. if adjacent pixels within feature maps are strongly correlated
  66. (as is normally the case in early convolution layers) then i.i.d. dropout
  67. will not regularize the activations and will otherwise just result
  68. in an effective learning rate decrease.
  69. In this case, :func:`nn.Dropout1d` will help promote independence between
  70. feature maps and should be used instead.
  71. Args:
  72. p (float, optional): probability of an element to be zero-ed.
  73. inplace (bool, optional): If set to ``True``, will do this operation
  74. in-place
  75. Shape:
  76. - Input: :math:`(N, C, L)` or :math:`(C, L)`.
  77. - Output: :math:`(N, C, L)` or :math:`(C, L)` (same shape as input).
  78. Examples::
  79. >>> m = nn.Dropout1d(p=0.2)
  80. >>> input = torch.randn(20, 16, 32)
  81. >>> output = m(input)
  82. .. _Efficient Object Localization Using Convolutional Networks:
  83. https://arxiv.org/abs/1411.4280
  84. """
  85. def forward(self, input: Tensor) -> Tensor:
  86. """
  87. Runs the forward pass.
  88. """
  89. return F.dropout1d(input, self.p, self.training, self.inplace)
  90. class Dropout2d(_DropoutNd):
  91. r"""Randomly zero out entire channels.
  92. A channel is a 2D feature map,
  93. e.g., the :math:`j`-th channel of the :math:`i`-th sample in the
  94. batched input is a 2D tensor :math:`\text{input}[i, j]`.
  95. Each channel will be zeroed out independently on every forward call with
  96. probability :attr:`p` using samples from a Bernoulli distribution.
  97. Usually the input comes from :class:`nn.Conv2d` modules.
  98. As described in the paper
  99. `Efficient Object Localization Using Convolutional Networks`_ ,
  100. if adjacent pixels within feature maps are strongly correlated
  101. (as is normally the case in early convolution layers) then i.i.d. dropout
  102. will not regularize the activations and will otherwise just result
  103. in an effective learning rate decrease.
  104. In this case, :func:`nn.Dropout2d` will help promote independence between
  105. feature maps and should be used instead.
  106. Args:
  107. p (float, optional): probability of an element to be zero-ed.
  108. inplace (bool, optional): If set to ``True``, will do this operation
  109. in-place
  110. .. warning ::
  111. Due to historical reasons, this class will perform 1D channel-wise dropout
  112. for 3D inputs (as done by :class:`nn.Dropout1d`). Thus, it currently does NOT
  113. support inputs without a batch dimension of shape :math:`(C, H, W)`. This
  114. behavior will change in a future release to interpret 3D inputs as no-batch-dim
  115. inputs. To maintain the old behavior, switch to :class:`nn.Dropout1d`.
  116. Shape:
  117. - Input: :math:`(N, C, H, W)` or :math:`(N, C, L)`.
  118. - Output: :math:`(N, C, H, W)` or :math:`(N, C, L)` (same shape as input).
  119. Examples::
  120. >>> m = nn.Dropout2d(p=0.2)
  121. >>> input = torch.randn(20, 16, 32, 32)
  122. >>> output = m(input)
  123. .. _Efficient Object Localization Using Convolutional Networks:
  124. https://arxiv.org/abs/1411.4280
  125. """
  126. def forward(self, input: Tensor) -> Tensor:
  127. """
  128. Runs the forward pass.
  129. """
  130. return F.dropout2d(input, self.p, self.training, self.inplace)
  131. class Dropout3d(_DropoutNd):
  132. r"""Randomly zero out entire channels.
  133. A channel is a 3D feature map,
  134. e.g., the :math:`j`-th channel of the :math:`i`-th sample in the
  135. batched input is a 3D tensor :math:`\text{input}[i, j]`.
  136. Each channel will be zeroed out independently on every forward call with
  137. probability :attr:`p` using samples from a Bernoulli distribution.
  138. Usually the input comes from :class:`nn.Conv3d` modules.
  139. As described in the paper
  140. `Efficient Object Localization Using Convolutional Networks`_ ,
  141. if adjacent pixels within feature maps are strongly correlated
  142. (as is normally the case in early convolution layers) then i.i.d. dropout
  143. will not regularize the activations and will otherwise just result
  144. in an effective learning rate decrease.
  145. In this case, :func:`nn.Dropout3d` will help promote independence between
  146. feature maps and should be used instead.
  147. Args:
  148. p (float, optional): probability of an element to be zeroed.
  149. inplace (bool, optional): If set to ``True``, will do this operation
  150. in-place
  151. Shape:
  152. - Input: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)`.
  153. - Output: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` (same shape as input).
  154. Examples::
  155. >>> m = nn.Dropout3d(p=0.2)
  156. >>> input = torch.randn(20, 16, 4, 32, 32)
  157. >>> output = m(input)
  158. .. _Efficient Object Localization Using Convolutional Networks:
  159. https://arxiv.org/abs/1411.4280
  160. """
  161. def forward(self, input: Tensor) -> Tensor:
  162. """
  163. Runs the forward pass.
  164. """
  165. return F.dropout3d(input, self.p, self.training, self.inplace)
  166. class AlphaDropout(_DropoutNd):
  167. r"""Applies Alpha Dropout over the input.
  168. Alpha Dropout is a type of Dropout that maintains the self-normalizing
  169. property.
  170. For an input with zero mean and unit standard deviation, the output of
  171. Alpha Dropout maintains the original mean and standard deviation of the
  172. input.
  173. Alpha Dropout goes hand-in-hand with SELU activation function, which ensures
  174. that the outputs have zero mean and unit standard deviation.
  175. During training, it randomly masks some of the elements of the input
  176. tensor with probability *p* using samples from a bernoulli distribution.
  177. The elements to masked are randomized on every forward call, and scaled
  178. and shifted to maintain zero mean and unit standard deviation.
  179. During evaluation the module simply computes an identity function.
  180. More details can be found in the paper `Self-Normalizing Neural Networks`_ .
  181. Args:
  182. p (float): probability of an element to be dropped. Default: 0.5
  183. inplace (bool, optional): If set to ``True``, will do this operation
  184. in-place
  185. Shape:
  186. - Input: :math:`(*)`. Input can be of any shape
  187. - Output: :math:`(*)`. Output is of the same shape as input
  188. Examples::
  189. >>> m = nn.AlphaDropout(p=0.2)
  190. >>> input = torch.randn(20, 16)
  191. >>> output = m(input)
  192. .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515
  193. """
  194. def forward(self, input: Tensor) -> Tensor:
  195. """
  196. Runs the forward pass.
  197. """
  198. return F.alpha_dropout(input, self.p, self.training)
  199. class FeatureAlphaDropout(_DropoutNd):
  200. r"""Randomly masks out entire channels.
  201. A channel is a feature map,
  202. e.g. the :math:`j`-th channel of the :math:`i`-th sample in the batch input
  203. is a tensor :math:`\text{input}[i, j]` of the input tensor). Instead of
  204. setting activations to zero, as in regular Dropout, the activations are set
  205. to the negative saturation value of the SELU activation function. More details
  206. can be found in the paper `Self-Normalizing Neural Networks`_ .
  207. Each element will be masked independently for each sample on every forward
  208. call with probability :attr:`p` using samples from a Bernoulli distribution.
  209. The elements to be masked are randomized on every forward call, and scaled
  210. and shifted to maintain zero mean and unit variance.
  211. Usually the input comes from :class:`nn.AlphaDropout` modules.
  212. As described in the paper
  213. `Efficient Object Localization Using Convolutional Networks`_ ,
  214. if adjacent pixels within feature maps are strongly correlated
  215. (as is normally the case in early convolution layers) then i.i.d. dropout
  216. will not regularize the activations and will otherwise just result
  217. in an effective learning rate decrease.
  218. In this case, :func:`nn.AlphaDropout` will help promote independence between
  219. feature maps and should be used instead.
  220. Args:
  221. p (float, optional): probability of an element to be zeroed. Default: 0.5
  222. inplace (bool, optional): If set to ``True``, will do this operation
  223. in-place
  224. Shape:
  225. - Input: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)`.
  226. - Output: :math:`(N, C, D, H, W)` or :math:`(C, D, H, W)` (same shape as input).
  227. Examples::
  228. >>> m = nn.FeatureAlphaDropout(p=0.2)
  229. >>> input = torch.randn(20, 16, 4, 32, 32)
  230. >>> output = m(input)
  231. .. _Self-Normalizing Neural Networks: https://arxiv.org/abs/1706.02515
  232. .. _Efficient Object Localization Using Convolutional Networks:
  233. https://arxiv.org/abs/1411.4280
  234. """
  235. def forward(self, input: Tensor) -> Tensor:
  236. """
  237. Runs the forward pass.
  238. """
  239. return F.feature_alpha_dropout(input, self.p, self.training)