| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263 |
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
- from typing import Callable, Optional
- from torch import Tensor, nn
- import torch.nn.functional as F
class SwiGLUFFN(nn.Module):
    """Feed-forward block with a SwiGLU gate: ``w3(silu(x1) * x2)``.

    One linear layer (``w12``) produces both branches in a single matmul; its
    output is split in half along the last dimension into ``x1`` (passed
    through SiLU) and ``x2``. Their element-wise product is projected by
    ``w3``.

    Args:
        in_features: Size of the last dimension of the input.
        hidden_features: Width of each of the two branches. Falls back to
            ``in_features`` when falsy (``None`` or 0).
        out_features: Size of the last dimension of the output. Falls back to
            ``in_features`` when falsy (``None`` or 0).
        act_layer: Accepted but never read; the activation is always SiLU.
            Presumably kept so the constructor matches sibling FFN layers.
        drop: Accepted but never read; no dropout is applied.
        bias: Whether both linear layers carry a bias term.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Optional[Callable[..., nn.Module]] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        # Fused projection: first half feeds the SiLU branch, second half the
        # multiplicative branch (see ``forward``).
        self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the gated feed-forward transform along the last dimension.

        Args:
            x: Input whose last dimension equals ``in_features``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``out_features``.
        """
        x12 = self.w12(x)
        x1, x2 = x12.chunk(2, dim=-1)
        hidden = F.silu(x1) * x2
        return self.w3(hidden)
# Prefer the xformers SwiGLU implementation when the package can be imported;
# otherwise alias SwiGLU to the pure-PyTorch SwiGLUFFN defined in this file.
try:
    from xformers.ops import SwiGLU
except ImportError:
    SwiGLU = SwiGLUFFN
    XFORMERS_AVAILABLE = False
else:
    XFORMERS_AVAILABLE = True
class SwiGLUFFNFused(SwiGLU):
    """``SwiGLU`` subclass that shrinks and aligns the hidden width.

    The requested hidden width is scaled by 2/3 (truncated to an int) and
    then rounded up to the next multiple of 8 before being handed to the
    base-class constructor. For example, ``hidden_features=4096`` becomes
    2736.
    NOTE(review): the 2/3 factor presumably offsets the extra projection a
    gated FFN has compared with a plain two-layer MLP, and the multiple-of-8
    rounding presumably targets kernel-friendly sizes -- confirm.

    Args:
        in_features: Size of the last dimension of the input.
        hidden_features: Nominal hidden width before the 2/3 scaling. Falls
            back to ``in_features`` when falsy (``None`` or 0).
        out_features: Size of the last dimension of the output. Falls back to
            ``in_features`` when falsy (``None`` or 0).
        act_layer: Accepted but never read or forwarded to the base class.
        drop: Accepted but never read or forwarded to the base class.
        bias: Forwarded to the base class as ``bias``.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Optional[Callable[..., nn.Module]] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        out_features = out_features or in_features
        nominal_hidden = hidden_features or in_features
        # Scale by 2/3, then round up to the next multiple of 8.
        scaled_hidden = int(nominal_hidden * 2 / 3)
        aligned_hidden = (scaled_hidden + 7) // 8 * 8
        super().__init__(
            in_features=in_features,
            hidden_features=aligned_hidden,
            out_features=out_features,
            bias=bias,
        )
|