# mlp.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
from typing import Callable, Optional

from torch import Tensor, nn
  11. class Mlp(nn.Module):
  12. def __init__(
  13. self,
  14. in_features: int,
  15. hidden_features: Optional[int] = None,
  16. out_features: Optional[int] = None,
  17. act_layer: Callable[..., nn.Module] = nn.GELU,
  18. drop: float = 0.0,
  19. bias: bool = True,
  20. ) -> None:
  21. super().__init__()
  22. out_features = out_features or in_features
  23. hidden_features = hidden_features or in_features
  24. self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
  25. self.act = act_layer()
  26. self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
  27. self.drop = nn.Dropout(drop)
  28. def forward(self, x: Tensor) -> Tensor:
  29. x = self.fc1(x)
  30. x = self.act(x)
  31. x = self.drop(x)
  32. x = self.fc2(x)
  33. x = self.drop(x)
  34. return x