# patch_dropout.py (3.0 KB)

from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
  4. def patch_dropout_forward(
  5. x: torch.Tensor,
  6. prob: float,
  7. num_prefix_tokens: int,
  8. ordered: bool,
  9. training: bool,
  10. ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
  11. """
  12. Common forward logic for patch dropout.
  13. Args:
  14. x: Input tensor of shape (B, L, D)
  15. prob: Dropout probability
  16. num_prefix_tokens: Number of prefix tokens to preserve
  17. ordered: Whether to maintain patch order
  18. training: Whether in training mode
  19. Returns:
  20. Tuple of (output tensor, keep_indices or None)
  21. """
  22. if not training or prob == 0.:
  23. return x, None
  24. if num_prefix_tokens:
  25. prefix_tokens, x = x[:, :num_prefix_tokens], x[:, num_prefix_tokens:]
  26. else:
  27. prefix_tokens = None
  28. B = x.shape[0]
  29. L = x.shape[1]
  30. num_keep = max(1, int(L * (1. - prob)))
  31. keep_indices = torch.argsort(torch.randn(B, L, device=x.device), dim=-1)[:, :num_keep]
  32. if ordered:
  33. # NOTE does not need to maintain patch order in typical transformer use,
  34. # but possibly useful for debug / visualization
  35. keep_indices = keep_indices.sort(dim=-1)[0]
  36. x = x.gather(1, keep_indices.unsqueeze(-1).expand((-1, -1) + x.shape[2:]))
  37. if prefix_tokens is not None:
  38. x = torch.cat((prefix_tokens, x), dim=1)
  39. return x, keep_indices
  40. class PatchDropout(nn.Module):
  41. """
  42. Patch Dropout without returning indices.
  43. https://arxiv.org/abs/2212.00794 and https://arxiv.org/pdf/2208.07220
  44. """
  45. def __init__(
  46. self,
  47. prob: float = 0.5,
  48. num_prefix_tokens: int = 1,
  49. ordered: bool = False,
  50. ):
  51. super().__init__()
  52. assert 0 <= prob < 1.
  53. self.prob = prob
  54. self.num_prefix_tokens = num_prefix_tokens # exclude CLS token (or other prefix tokens)
  55. self.ordered = ordered
  56. def forward(self, x: torch.Tensor) -> torch.Tensor:
  57. output, _ = patch_dropout_forward(
  58. x,
  59. self.prob,
  60. self.num_prefix_tokens,
  61. self.ordered,
  62. self.training
  63. )
  64. return output
  65. class PatchDropoutWithIndices(nn.Module):
  66. """
  67. Patch Dropout that returns both output and keep indices.
  68. https://arxiv.org/abs/2212.00794 and https://arxiv.org/pdf/2208.07220
  69. """
  70. def __init__(
  71. self,
  72. prob: float = 0.5,
  73. num_prefix_tokens: int = 1,
  74. ordered: bool = False,
  75. ):
  76. super().__init__()
  77. assert 0 <= prob < 1.
  78. self.prob = prob
  79. self.num_prefix_tokens = num_prefix_tokens # exclude CLS token (or other prefix tokens)
  80. self.ordered = ordered
  81. def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
  82. return patch_dropout_forward(
  83. x,
  84. self.prob,
  85. self.num_prefix_tokens,
  86. self.ordered,
  87. self.training
  88. )