# mlp.py (1.1 KB)
# Copyright (c) 2024, Tri Dao, Albert Gu.
from torch import nn
from torch.nn import functional as F
  4. class GatedMLP(nn.Module):
  5. def __init__(
  6. self,
  7. in_features,
  8. hidden_features=None,
  9. out_features=None,
  10. activation=F.silu,
  11. bias=False,
  12. multiple_of=128,
  13. device=None,
  14. dtype=None,
  15. ):
  16. factory_kwargs = {"device": device, "dtype": dtype}
  17. super().__init__()
  18. out_features = out_features if out_features is not None else in_features
  19. hidden_features = (
  20. hidden_features if hidden_features is not None else int(8 * in_features / 3)
  21. )
  22. hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of
  23. self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias, **factory_kwargs)
  24. self.activation = activation
  25. self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, **factory_kwargs)
  26. def forward(self, x):
  27. y = self.fc1(x)
  28. y, gate = y.chunk(2, dim=-1)
  29. y = y * self.activation(gate)
  30. y = self.fc2(y)
  31. return y