| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290 |
- """ MLP module w/ dropout and configurable activation layer
- Hacked together by / Copyright 2020 Ross Wightman
- """
- from functools import partial
- from typing import Optional, Type, Union, Tuple
- from torch import nn as nn
- from .grn import GlobalResponseNorm
- from .helpers import to_2tuple
class Mlp(nn.Module):
    """ Feed-forward MLP block as used in Vision Transformer, MLP-Mixer and related networks.

    Layout: fc1 -> act -> drop1 -> norm -> fc2 -> drop2.
    NOTE: When use_conv=True, expects 2D NCHW tensors, otherwise N*C expected.

    Args:
        in_features: Input feature / channel count.
        hidden_features: Hidden width; falls back to ``in_features`` when falsy.
        out_features: Output width; falls back to ``in_features`` when falsy.
        act_layer: Activation class, instantiated with no arguments.
        norm_layer: Optional norm class applied to the hidden activations
            (``nn.Identity`` when None).
        bias: Bias flag(s) passed through ``to_2tuple``; element 0 is used for fc1,
            element 1 for fc2.
        drop: Dropout probability(ies) passed through ``to_2tuple``; element 0 is used
            for drop1, element 1 for drop2.
        use_conv: Use 1x1 ``nn.Conv2d`` projections instead of ``nn.Linear``.
        device: Forwarded to the fc and norm layers.
        dtype: Forwarded to the fc and norm layers.
    """

    def __init__(
            self,
            in_features: int,
            hidden_features: Optional[int] = None,
            out_features: Optional[int] = None,
            act_layer: Type[nn.Module] = nn.GELU,
            norm_layer: Optional[Type[nn.Module]] = None,
            bias: Union[bool, Tuple[bool, bool]] = True,
            drop: Union[float, Tuple[float, float]] = 0.,
            use_conv: bool = False,
            device=None,
            dtype=None,
    ):
        factory_kwargs = dict(device=device, dtype=dtype)
        super().__init__()
        hidden_dim = hidden_features or in_features
        out_dim = out_features or in_features
        use_bias = to_2tuple(bias)
        drop_rates = to_2tuple(drop)
        if use_conv:
            proj = partial(nn.Conv2d, kernel_size=1)
        else:
            proj = nn.Linear

        # NOTE: creation order (fc1 before fc2) is kept so parameter init / ordering is unchanged
        self.fc1 = proj(in_features, hidden_dim, bias=use_bias[0], **factory_kwargs)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_rates[0])
        if norm_layer is None:
            self.norm = nn.Identity()
        else:
            self.norm = norm_layer(hidden_dim, **factory_kwargs)
        self.fc2 = proj(hidden_dim, out_dim, bias=use_bias[1], **factory_kwargs)
        self.drop2 = nn.Dropout(drop_rates[1])

    def forward(self, x):
        # hidden projection, activation, then dropout
        hidden = self.drop1(self.act(self.fc1(x)))
        # optional normalization of the hidden activations
        hidden = self.norm(hidden)
        # output projection followed by dropout
        return self.drop2(self.fc2(hidden))
class GluMlp(nn.Module):
    """ MLP w/ GLU style gating
    See: https://arxiv.org/abs/1612.08083, https://arxiv.org/abs/2002.05202
    NOTE: When use_conv=True, expects 2D NCHW tensors, otherwise N*C expected.

    ``fc1`` projects to ``hidden_features`` channels, which are chunked into two equal
    halves along the channel dim (dim 1 when ``use_conv``, last dim otherwise). One half
    is passed through ``act`` and multiplies the other half. The product
    (``hidden_features // 2`` channels) then goes through drop1 -> norm -> fc2 -> drop2.

    Args:
        in_features: Input feature / channel count.
        hidden_features: fc1 output width, must be even; falls back to ``in_features``
            when falsy.
        out_features: Output width; falls back to ``in_features`` when falsy.
        act_layer: Gate activation class, instantiated with no arguments.
        norm_layer: Optional norm class applied to the gated activations
            (``hidden_features // 2`` wide); ``nn.Identity`` when None.
        bias: Bias flag(s) passed through ``to_2tuple``; element 0 for fc1, element 1 for fc2.
        drop: Dropout probability(ies) passed through ``to_2tuple``; element 0 for drop1,
            element 1 for drop2.
        use_conv: Use 1x1 ``nn.Conv2d`` projections instead of ``nn.Linear``.
        gate_last: If True the second half of fc1's output is the gate (``x1 * act(x2)``);
            if False the first half is the gate (``act(x1) * x2``).
        device: Forwarded to the fc and norm layers.
        dtype: Forwarded to the fc and norm layers.
    """

    def __init__(
            self,
            in_features: int,
            hidden_features: Optional[int] = None,
            out_features: Optional[int] = None,
            act_layer: Type[nn.Module] = nn.Sigmoid,
            norm_layer: Optional[Type[nn.Module]] = None,
            bias: Union[bool, Tuple[bool, bool]] = True,
            drop: Union[float, Tuple[float, float]] = 0.,
            use_conv: bool = False,
            gate_last: bool = True,
            device=None,
            dtype=None,
    ):
        dd = {'device': device, 'dtype': dtype}
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        # fc1 output is chunked into two equal halves (value / gate), so width must be even
        assert hidden_features % 2 == 0, \
            f'hidden_features must be even for GLU gating, got {hidden_features}'
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)
        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
        self.chunk_dim = 1 if use_conv else -1  # channel dim: NCHW for conv, last dim for linear
        self.gate_last = gate_last  # use second half of width for gate

        self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0], **dd)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.norm = norm_layer(hidden_features // 2, **dd) if norm_layer is not None else nn.Identity()
        self.fc2 = linear_layer(hidden_features // 2, out_features, bias=bias[1], **dd)
        self.drop2 = nn.Dropout(drop_probs[1])

    def init_weights(self):
        """ Override init of the gate half of fc1: weight ~ N(0, 1e-6), bias = 1.

        The gate half is selected from ``gate_last`` so it matches ``forward``: with
        ``gate_last=False`` (e.g. ``SwiGLUPacked``) the gate is the *first* half of fc1's
        output channels, not the second.
        """
        # fc1 weight/bias dim 0 is the output-channel dim for both nn.Linear and 1x1 nn.Conv2d
        half = self.fc1.weight.shape[0] // 2
        gate = slice(half, None) if self.gate_last else slice(0, half)
        if self.fc1.bias is not None:
            nn.init.ones_(self.fc1.bias[gate])
        nn.init.normal_(self.fc1.weight[gate], std=1e-6)

    def forward(self, x):
        x = self.fc1(x)
        x1, x2 = x.chunk(2, dim=self.chunk_dim)
        x = x1 * self.act(x2) if self.gate_last else self.act(x1) * x2
        x = self.drop1(x)
        x = self.norm(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x
# SwiGLU expressed via GluMlp's packed (single) fc1 projection.
SwiGLUPacked = partial(
    GluMlp,
    act_layer=nn.SiLU,  # SiLU activation applied to the gate half
    gate_last=False,  # gate is the first half of fc1's output: act(x1) * x2
)
class SwiGLU(nn.Module):
    """ SwiGLU feed-forward block with separate gate / value projections.

    NOTE: GluMLP above can implement SwiGLU, but this impl has split fc1 and
    better matches some other common impl which makes mapping checkpoints simpler.

    Computes ``drop2(fc2(norm(drop1(act(fc1_g(x)) * fc1_x(x)))))``.

    Args:
        in_features: Input feature count.
        hidden_features: Width of fc1_g / fc1_x outputs; falls back to ``in_features`` when falsy.
        out_features: Output width; falls back to ``in_features`` when falsy.
        act_layer: Gate activation class, instantiated with no arguments.
        norm_layer: Optional norm class applied to the gated activations
            (``nn.Identity`` when None).
        bias: Bias flag(s) passed through ``to_2tuple``; element 0 is shared by fc1_g and
            fc1_x, element 1 is used for fc2.
        drop: Dropout probability(ies) passed through ``to_2tuple``; element 0 for drop1,
            element 1 for drop2.
        align_to: When truthy, the hidden width is rounded up to a multiple of this value.
        device: Forwarded to the fc and norm layers.
        dtype: Forwarded to the fc and norm layers.
    """

    def __init__(
            self,
            in_features: int,
            hidden_features: Optional[int] = None,
            out_features: Optional[int] = None,
            act_layer: Type[nn.Module] = nn.SiLU,
            norm_layer: Optional[Type[nn.Module]] = None,
            bias: Union[bool, Tuple[bool, bool]] = True,
            drop: Union[float, Tuple[float, float]] = 0.,
            align_to: int = 0,
            device=None,
            dtype=None,
    ):
        factory_kwargs = dict(device=device, dtype=dtype)
        super().__init__()
        hidden_dim = hidden_features or in_features
        out_dim = out_features or in_features
        use_bias = to_2tuple(bias)
        drop_rates = to_2tuple(drop)
        if align_to:
            # round up to the next multiple of align_to
            hidden_dim += -hidden_dim % align_to

        # creation order kept identical so parameter init / ordering is unchanged
        self.fc1_g = nn.Linear(in_features, hidden_dim, bias=use_bias[0], **factory_kwargs)
        self.fc1_x = nn.Linear(in_features, hidden_dim, bias=use_bias[0], **factory_kwargs)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_rates[0])
        if norm_layer is None:
            self.norm = nn.Identity()
        else:
            self.norm = norm_layer(hidden_dim, **factory_kwargs)
        self.fc2 = nn.Linear(hidden_dim, out_dim, bias=use_bias[1], **factory_kwargs)
        self.drop2 = nn.Dropout(drop_rates[1])

    def init_weights(self):
        # gate projection starts with near-zero weights and a bias of one
        gate_proj = self.fc1_g
        if gate_proj.bias is not None:
            nn.init.ones_(gate_proj.bias)
        nn.init.normal_(gate_proj.weight, std=1e-6)

    def forward(self, x):
        gate_pre = self.fc1_g(x)
        value = self.fc1_x(x)
        hidden = self.drop1(self.act(gate_pre) * value)
        hidden = self.norm(hidden)
        return self.drop2(self.fc2(hidden))
class GatedMlp(nn.Module):
    """ MLP as used in gMLP.

    Layout: fc1 -> act -> drop1 -> gate -> norm -> fc2 -> drop2.

    Args:
        in_features: Input feature count.
        hidden_features: fc1 output width; falls back to ``in_features`` when falsy.
        out_features: Output width; falls back to ``in_features`` when falsy.
        act_layer: Activation class, instantiated with no arguments.
        norm_layer: Optional norm class applied after the gate (``nn.Identity`` when None).
        gate_layer: Optional gating module class built with the fc1 width. When given,
            that width must be even and the layers after it are sized to half of it.
        bias: Bias flag(s) passed through ``to_2tuple``; element 0 for fc1, element 1 for fc2.
        drop: Dropout probability(ies) passed through ``to_2tuple``; element 0 for drop1,
            element 1 for drop2.
        device: Forwarded to the fc, gate and norm layers.
        dtype: Forwarded to the fc, gate and norm layers.
    """

    def __init__(
            self,
            in_features: int,
            hidden_features: Optional[int] = None,
            out_features: Optional[int] = None,
            act_layer: Type[nn.Module] = nn.GELU,
            norm_layer: Optional[Type[nn.Module]] = None,
            gate_layer: Optional[Type[nn.Module]] = None,
            bias: Union[bool, Tuple[bool, bool]] = True,
            drop: Union[float, Tuple[float, float]] = 0.,
            device=None,
            dtype=None,
    ):
        factory_kwargs = dict(device=device, dtype=dtype)
        super().__init__()
        hidden_dim = hidden_features or in_features
        out_dim = out_features or in_features
        use_bias = to_2tuple(bias)
        drop_rates = to_2tuple(drop)

        self.fc1 = nn.Linear(in_features, hidden_dim, bias=use_bias[0], **factory_kwargs)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_rates[0])
        if gate_layer is None:
            self.gate = nn.Identity()
        else:
            assert hidden_dim % 2 == 0
            self.gate = gate_layer(hidden_dim, **factory_kwargs)
            hidden_dim //= 2  # FIXME base reduction on gate property?
        if norm_layer is None:
            self.norm = nn.Identity()
        else:
            self.norm = norm_layer(hidden_dim, **factory_kwargs)
        self.fc2 = nn.Linear(hidden_dim, out_dim, bias=use_bias[1], **factory_kwargs)
        self.drop2 = nn.Dropout(drop_rates[1])

    def forward(self, x):
        hidden = self.drop1(self.act(self.fc1(x)))
        # gating (and optional norm) on the hidden activations
        hidden = self.norm(self.gate(hidden))
        return self.drop2(self.fc2(hidden))
class ConvMlp(nn.Module):
    """ MLP built from 1x1 convolutions so spatial dims are preserved (2D NCHW tensors).

    Layout: fc1 -> norm -> act -> drop -> fc2.

    Args:
        in_features: Input channel count.
        hidden_features: fc1 output channels; falls back to ``in_features`` when falsy.
        out_features: Output channels; falls back to ``in_features`` when falsy.
        act_layer: Activation class, instantiated with no arguments.
        norm_layer: Optional norm class applied right after fc1; ``nn.Identity`` when falsy.
        bias: Bias flag(s) passed through ``to_2tuple``; element 0 for fc1, element 1 for fc2.
        drop: Dropout probability applied before fc2.
        device: Forwarded to the conv and norm layers.
        dtype: Forwarded to the conv and norm layers.
    """

    def __init__(
            self,
            in_features: int,
            hidden_features: Optional[int] = None,
            out_features: Optional[int] = None,
            act_layer: Type[nn.Module] = nn.ReLU,
            norm_layer: Optional[Type[nn.Module]] = None,
            bias: Union[bool, Tuple[bool, bool]] = True,
            drop: float = 0.,
            device=None,
            dtype=None,
    ):
        factory_kwargs = dict(device=device, dtype=dtype)
        super().__init__()
        hidden_dim = hidden_features or in_features
        out_dim = out_features or in_features
        use_bias = to_2tuple(bias)

        self.fc1 = nn.Conv2d(in_features, hidden_dim, kernel_size=1, bias=use_bias[0], **factory_kwargs)
        if norm_layer:
            self.norm = norm_layer(hidden_dim, **factory_kwargs)
        else:
            self.norm = nn.Identity()
        self.act = act_layer()
        self.drop = nn.Dropout(drop)
        self.fc2 = nn.Conv2d(hidden_dim, out_dim, kernel_size=1, bias=use_bias[1], **factory_kwargs)

    def forward(self, x):
        # note: norm is applied before the activation in this variant
        hidden = self.act(self.norm(self.fc1(x)))
        return self.fc2(self.drop(hidden))
class GlobalResponseNormMlp(nn.Module):
    """ MLP w/ Global Response Norm (see grn.py), nn.Linear or 1x1 Conv2d
    NOTE: Intended for '2D' NCHW (use_conv=True) or NHWC (use_conv=False, channels-last) tensor layouts

    Layout: fc1 -> act -> drop1 -> grn -> fc2 -> drop2.

    Args:
        in_features: Input feature / channel count.
        hidden_features: Hidden width; falls back to ``in_features`` when falsy.
        out_features: Output width; falls back to ``in_features`` when falsy.
        act_layer: Activation class, instantiated with no arguments.
        bias: Bias flag(s) passed through ``to_2tuple``; element 0 for fc1, element 1 for fc2.
        drop: Dropout probability(ies) passed through ``to_2tuple``; element 0 for drop1,
            element 1 for drop2.
        use_conv: Use 1x1 ``nn.Conv2d`` projections instead of ``nn.Linear``.
        device: Forwarded to the fc and GRN layers.
        dtype: Forwarded to the fc and GRN layers.
    """

    def __init__(
            self,
            in_features: int,
            hidden_features: Optional[int] = None,
            out_features: Optional[int] = None,
            act_layer: Type[nn.Module] = nn.GELU,
            bias: Union[bool, Tuple[bool, bool]] = True,
            drop: Union[float, Tuple[float, float]] = 0.,
            use_conv: bool = False,
            device=None,
            dtype=None,
    ):
        factory_kwargs = dict(device=device, dtype=dtype)
        super().__init__()
        hidden_dim = hidden_features or in_features
        out_dim = out_features or in_features
        use_bias = to_2tuple(bias)
        drop_rates = to_2tuple(drop)
        if use_conv:
            proj = partial(nn.Conv2d, kernel_size=1)
        else:
            proj = nn.Linear

        self.fc1 = proj(in_features, hidden_dim, bias=use_bias[0], **factory_kwargs)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_rates[0])
        # channels_last is set for the nn.Linear path (NHWC per the class NOTE), off for 1x1 convs
        self.grn = GlobalResponseNorm(hidden_dim, channels_last=not use_conv, **factory_kwargs)
        self.fc2 = proj(hidden_dim, out_dim, bias=use_bias[1], **factory_kwargs)
        self.drop2 = nn.Dropout(drop_rates[1])

    def forward(self, x):
        hidden = self.drop1(self.act(self.fc1(x)))
        hidden = self.grn(hidden)
        return self.drop2(self.fc2(hidden))
|