mlp.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290
  1. """ MLP module w/ dropout and configurable activation layer
  2. Hacked together by / Copyright 2020 Ross Wightman
  3. """
  4. from functools import partial
  5. from typing import Optional, Type, Union, Tuple
  6. from torch import nn as nn
  7. from .grn import GlobalResponseNorm
  8. from .helpers import to_2tuple
  9. class Mlp(nn.Module):
  10. """ MLP as used in Vision Transformer, MLP-Mixer and related networks
  11. NOTE: When use_conv=True, expects 2D NCHW tensors, otherwise N*C expected.
  12. """
  13. def __init__(
  14. self,
  15. in_features: int,
  16. hidden_features: Optional[int] = None,
  17. out_features: Optional[int] = None,
  18. act_layer: Type[nn.Module] = nn.GELU,
  19. norm_layer: Optional[Type[nn.Module]] = None,
  20. bias: Union[bool, Tuple[bool, bool]] = True,
  21. drop: Union[float, Tuple[float, float]] = 0.,
  22. use_conv: bool = False,
  23. device=None,
  24. dtype=None,
  25. ):
  26. dd = {'device': device, 'dtype': dtype}
  27. super().__init__()
  28. out_features = out_features or in_features
  29. hidden_features = hidden_features or in_features
  30. bias = to_2tuple(bias)
  31. drop_probs = to_2tuple(drop)
  32. linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
  33. self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0], **dd)
  34. self.act = act_layer()
  35. self.drop1 = nn.Dropout(drop_probs[0])
  36. self.norm = norm_layer(hidden_features, **dd) if norm_layer is not None else nn.Identity()
  37. self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1], **dd)
  38. self.drop2 = nn.Dropout(drop_probs[1])
  39. def forward(self, x):
  40. x = self.fc1(x)
  41. x = self.act(x)
  42. x = self.drop1(x)
  43. x = self.norm(x)
  44. x = self.fc2(x)
  45. x = self.drop2(x)
  46. return x
  47. class GluMlp(nn.Module):
  48. """ MLP w/ GLU style gating
  49. See: https://arxiv.org/abs/1612.08083, https://arxiv.org/abs/2002.05202
  50. NOTE: When use_conv=True, expects 2D NCHW tensors, otherwise N*C expected.
  51. """
  52. def __init__(
  53. self,
  54. in_features: int,
  55. hidden_features: Optional[int] = None,
  56. out_features: Optional[int] = None,
  57. act_layer: Type[nn.Module] = nn.Sigmoid,
  58. norm_layer: Optional[Type[nn.Module]] = None,
  59. bias: Union[bool, Tuple[bool, bool]] = True,
  60. drop: Union[float, Tuple[float, float]] = 0.,
  61. use_conv: bool = False,
  62. gate_last: bool = True,
  63. device=None,
  64. dtype=None,
  65. ):
  66. dd = {'device': device, 'dtype': dtype}
  67. super().__init__()
  68. out_features = out_features or in_features
  69. hidden_features = hidden_features or in_features
  70. assert hidden_features % 2 == 0
  71. bias = to_2tuple(bias)
  72. drop_probs = to_2tuple(drop)
  73. linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
  74. self.chunk_dim = 1 if use_conv else -1
  75. self.gate_last = gate_last # use second half of width for gate
  76. self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0], **dd)
  77. self.act = act_layer()
  78. self.drop1 = nn.Dropout(drop_probs[0])
  79. self.norm = norm_layer(hidden_features // 2, **dd) if norm_layer is not None else nn.Identity()
  80. self.fc2 = linear_layer(hidden_features // 2, out_features, bias=bias[1], **dd)
  81. self.drop2 = nn.Dropout(drop_probs[1])
  82. def init_weights(self):
  83. # override init of fc1 w/ gate portion set to weight near zero, bias=1
  84. if self.fc1.bias is not None:
  85. nn.init.ones_(self.fc1.bias[self.fc1.bias.shape[0] // 2:])
  86. nn.init.normal_(self.fc1.weight[self.fc1.weight.shape[0] // 2:], std=1e-6)
  87. def forward(self, x):
  88. x = self.fc1(x)
  89. x1, x2 = x.chunk(2, dim=self.chunk_dim)
  90. x = x1 * self.act(x2) if self.gate_last else self.act(x1) * x2
  91. x = self.drop1(x)
  92. x = self.norm(x)
  93. x = self.fc2(x)
  94. x = self.drop2(x)
  95. return x
  96. SwiGLUPacked = partial(GluMlp, act_layer=nn.SiLU, gate_last=False)
  97. class SwiGLU(nn.Module):
  98. """ SwiGLU
  99. NOTE: GluMLP above can implement SwiGLU, but this impl has split fc1 and
  100. better matches some other common impl which makes mapping checkpoints simpler.
  101. """
  102. def __init__(
  103. self,
  104. in_features: int,
  105. hidden_features: Optional[int] = None,
  106. out_features: Optional[int] = None,
  107. act_layer: Type[nn.Module] = nn.SiLU,
  108. norm_layer: Optional[Type[nn.Module]] = None,
  109. bias: Union[bool, Tuple[bool, bool]] = True,
  110. drop: Union[float, Tuple[float, float]] = 0.,
  111. align_to: int = 0,
  112. device=None,
  113. dtype=None,
  114. ):
  115. dd = {'device': device, 'dtype': dtype}
  116. super().__init__()
  117. out_features = out_features or in_features
  118. hidden_features = hidden_features or in_features
  119. bias = to_2tuple(bias)
  120. drop_probs = to_2tuple(drop)
  121. if align_to:
  122. hidden_features = hidden_features + (-hidden_features % align_to)
  123. self.fc1_g = nn.Linear(in_features, hidden_features, bias=bias[0], **dd)
  124. self.fc1_x = nn.Linear(in_features, hidden_features, bias=bias[0], **dd)
  125. self.act = act_layer()
  126. self.drop1 = nn.Dropout(drop_probs[0])
  127. self.norm = norm_layer(hidden_features, **dd) if norm_layer is not None else nn.Identity()
  128. self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1], **dd)
  129. self.drop2 = nn.Dropout(drop_probs[1])
  130. def init_weights(self):
  131. # override init of fc1 w/ gate portion set to weight near zero, bias=1
  132. if self.fc1_g.bias is not None:
  133. nn.init.ones_(self.fc1_g.bias)
  134. nn.init.normal_(self.fc1_g.weight, std=1e-6)
  135. def forward(self, x):
  136. x_gate = self.fc1_g(x)
  137. x = self.fc1_x(x)
  138. x = self.act(x_gate) * x
  139. x = self.drop1(x)
  140. x = self.norm(x)
  141. x = self.fc2(x)
  142. x = self.drop2(x)
  143. return x
  144. class GatedMlp(nn.Module):
  145. """ MLP as used in gMLP
  146. """
  147. def __init__(
  148. self,
  149. in_features: int,
  150. hidden_features: Optional[int] = None,
  151. out_features: Optional[int] = None,
  152. act_layer: Type[nn.Module] = nn.GELU,
  153. norm_layer: Optional[Type[nn.Module]] = None,
  154. gate_layer: Optional[Type[nn.Module]] = None,
  155. bias: Union[bool, Tuple[bool, bool]] = True,
  156. drop: Union[float, Tuple[float, float]] = 0.,
  157. device=None,
  158. dtype=None,
  159. ):
  160. dd = {'device': device, 'dtype': dtype}
  161. super().__init__()
  162. out_features = out_features or in_features
  163. hidden_features = hidden_features or in_features
  164. bias = to_2tuple(bias)
  165. drop_probs = to_2tuple(drop)
  166. self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0], **dd)
  167. self.act = act_layer()
  168. self.drop1 = nn.Dropout(drop_probs[0])
  169. if gate_layer is not None:
  170. assert hidden_features % 2 == 0
  171. self.gate = gate_layer(hidden_features, **dd)
  172. hidden_features = hidden_features // 2 # FIXME base reduction on gate property?
  173. else:
  174. self.gate = nn.Identity()
  175. self.norm = norm_layer(hidden_features, **dd) if norm_layer is not None else nn.Identity()
  176. self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1], **dd)
  177. self.drop2 = nn.Dropout(drop_probs[1])
  178. def forward(self, x):
  179. x = self.fc1(x)
  180. x = self.act(x)
  181. x = self.drop1(x)
  182. x = self.gate(x)
  183. x = self.norm(x)
  184. x = self.fc2(x)
  185. x = self.drop2(x)
  186. return x
  187. class ConvMlp(nn.Module):
  188. """ MLP using 1x1 convs that keeps spatial dims (for 2D NCHW tensors)
  189. """
  190. def __init__(
  191. self,
  192. in_features: int,
  193. hidden_features: Optional[int] = None,
  194. out_features: Optional[int] = None,
  195. act_layer: Type[nn.Module] = nn.ReLU,
  196. norm_layer: Optional[Type[nn.Module]] = None,
  197. bias: Union[bool, Tuple[bool, bool]] = True,
  198. drop: float = 0.,
  199. device=None,
  200. dtype=None,
  201. ):
  202. dd = {'device': device, 'dtype': dtype}
  203. super().__init__()
  204. out_features = out_features or in_features
  205. hidden_features = hidden_features or in_features
  206. bias = to_2tuple(bias)
  207. self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, bias=bias[0], **dd)
  208. self.norm = norm_layer(hidden_features, **dd) if norm_layer else nn.Identity()
  209. self.act = act_layer()
  210. self.drop = nn.Dropout(drop)
  211. self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1, bias=bias[1], **dd)
  212. def forward(self, x):
  213. x = self.fc1(x)
  214. x = self.norm(x)
  215. x = self.act(x)
  216. x = self.drop(x)
  217. x = self.fc2(x)
  218. return x
  219. class GlobalResponseNormMlp(nn.Module):
  220. """ MLP w/ Global Response Norm (see grn.py), nn.Linear or 1x1 Conv2d
  221. NOTE: Intended for '2D' NCHW (use_conv=True) or NHWC (use_conv=False, channels-last) tensor layouts
  222. """
  223. def __init__(
  224. self,
  225. in_features: int,
  226. hidden_features: Optional[int] = None,
  227. out_features: Optional[int] = None,
  228. act_layer: Type[nn.Module] = nn.GELU,
  229. bias: Union[bool, Tuple[bool, bool]] = True,
  230. drop: Union[float, Tuple[float, float]] = 0.,
  231. use_conv: bool = False,
  232. device=None,
  233. dtype=None,
  234. ):
  235. dd = {'device': device, 'dtype': dtype}
  236. super().__init__()
  237. out_features = out_features or in_features
  238. hidden_features = hidden_features or in_features
  239. bias = to_2tuple(bias)
  240. drop_probs = to_2tuple(drop)
  241. linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
  242. self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0], **dd)
  243. self.act = act_layer()
  244. self.drop1 = nn.Dropout(drop_probs[0])
  245. self.grn = GlobalResponseNorm(hidden_features, channels_last=not use_conv, **dd)
  246. self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1], **dd)
  247. self.drop2 = nn.Dropout(drop_probs[1])
  248. def forward(self, x):
  249. x = self.fc1(x)
  250. x = self.act(x)
  251. x = self.drop1(x)
  252. x = self.grn(x)
  253. x = self.fc2(x)
  254. x = self.drop2(x)
  255. return x