| 12345678910111213141516171819202122232425262728293031323334 |
- # Copyright (c) 2024, Tri Dao, Albert Gu.
- from torch import nn
- from torch.nn import functional as F
class GatedMLP(nn.Module):
    """Two-layer MLP whose hidden activations are modulated by a gate.

    ``fc1`` projects the input to ``2 * hidden_features`` channels. The
    result is split in half along the last dimension; the first half is
    multiplied element-wise by ``activation`` applied to the second half,
    and ``fc2`` projects the product to ``out_features``.

    Args:
        in_features: Size of the last dimension of the input.
        hidden_features: Requested hidden width. When ``None`` it defaults
            to ``int(8 * in_features / 3)``. In both cases the value is
            then rounded up to the next multiple of ``multiple_of``.
        out_features: Size of the last dimension of the output. Defaults
            to ``in_features`` when ``None``.
        activation: Callable applied to the gate half of ``fc1``'s output.
        bias: Whether ``fc1`` and ``fc2`` carry a bias term.
        multiple_of: Granularity the hidden width is rounded up to.
        device: Forwarded to both ``nn.Linear`` constructors.
        dtype: Forwarded to both ``nn.Linear`` constructors.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        activation=F.silu,
        bias=False,
        multiple_of=128,
        device=None,
        dtype=None,
    ):
        super().__init__()
        linear_kwargs = dict(device=device, dtype=dtype)

        if out_features is None:
            out_features = in_features
        if hidden_features is None:
            hidden_features = int(8 * in_features / 3)

        # Round the hidden width up to the nearest multiple of `multiple_of`.
        padded = hidden_features + multiple_of - 1
        hidden_width = padded // multiple_of * multiple_of

        # fc1 emits the value half and the gate half in a single projection.
        # Keep the fc1 -> activation -> fc2 assignment order: it determines
        # submodule/parameter registration order and weight-init RNG usage.
        self.fc1 = nn.Linear(in_features, 2 * hidden_width, bias=bias, **linear_kwargs)
        self.activation = activation
        self.fc2 = nn.Linear(hidden_width, out_features, bias=bias, **linear_kwargs)

    def forward(self, x):
        """Apply the gated MLP to ``x`` and return the projected result."""
        value, gate = self.fc1(x).chunk(2, dim=-1)
        return self.fc2(value * self.activation(gate))
|