# layer_scale.py
import torch
from torch import nn
  3. class LayerScale(nn.Module):
  4. """ LayerScale on tensors with channels in last-dim.
  5. """
  6. def __init__(
  7. self,
  8. dim: int,
  9. init_values: float = 1e-5,
  10. inplace: bool = False,
  11. device=None,
  12. dtype=None,
  13. ) -> None:
  14. super().__init__()
  15. self.init_values = init_values
  16. self.inplace = inplace
  17. self.gamma = nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
  18. self.reset_parameters()
  19. def reset_parameters(self):
  20. torch.nn.init.constant_(self.gamma, self.init_values)
  21. def forward(self, x: torch.Tensor) -> torch.Tensor:
  22. return x.mul_(self.gamma) if self.inplace else x * self.gamma
  23. class LayerScale2d(nn.Module):
  24. """ LayerScale for tensors with torch 2D NCHW layout.
  25. """
  26. def __init__(
  27. self,
  28. dim: int,
  29. init_values: float = 1e-5,
  30. inplace: bool = False,
  31. device=None,
  32. dtype=None,
  33. ):
  34. super().__init__()
  35. self.init_values = init_values
  36. self.inplace = inplace
  37. self.gamma = nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
  38. self.reset_parameters()
  39. def reset_parameters(self):
  40. torch.nn.init.constant_(self.gamma, self.init_values)
  41. def forward(self, x):
  42. gamma = self.gamma.view(1, -1, 1, 1)
  43. return x.mul_(gamma) if self.inplace else x * gamma