swiglu_ffn.py 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the license found in the
  5. # LICENSE file in the root directory of this source tree.
  6. from typing import Callable, Optional
  7. from torch import Tensor, nn
  8. import torch.nn.functional as F
  9. class SwiGLUFFN(nn.Module):
  10. def __init__(
  11. self,
  12. in_features: int,
  13. hidden_features: Optional[int] = None,
  14. out_features: Optional[int] = None,
  15. act_layer: Callable[..., nn.Module] = None,
  16. drop: float = 0.0,
  17. bias: bool = True,
  18. ) -> None:
  19. super().__init__()
  20. out_features = out_features or in_features
  21. hidden_features = hidden_features or in_features
  22. self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
  23. self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
  24. def forward(self, x: Tensor) -> Tensor:
  25. x12 = self.w12(x)
  26. x1, x2 = x12.chunk(2, dim=-1)
  27. hidden = F.silu(x1) * x2
  28. return self.w3(hidden)
  29. try:
  30. from xformers.ops import SwiGLU
  31. XFORMERS_AVAILABLE = True
  32. except ImportError:
  33. SwiGLU = SwiGLUFFN
  34. XFORMERS_AVAILABLE = False
  35. class SwiGLUFFNFused(SwiGLU):
  36. def __init__(
  37. self,
  38. in_features: int,
  39. hidden_features: Optional[int] = None,
  40. out_features: Optional[int] = None,
  41. act_layer: Callable[..., nn.Module] = None,
  42. drop: float = 0.0,
  43. bias: bool = True,
  44. ) -> None:
  45. out_features = out_features or in_features
  46. hidden_features = hidden_features or in_features
  47. hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
  48. super().__init__(
  49. in_features=in_features,
  50. hidden_features=hidden_features,
  51. out_features=out_features,
  52. bias=bias,
  53. )