| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319 |
- # LICENSE HEADER MANAGED BY add-license-header
- #
- # Copyright 2018 Kornia Team
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- from typing import Any, Dict, Tuple
- import torch
- from torch import nn
- from kornia.core import Module, Tensor
def conv_1x1_bn(inp: int, oup: int) -> Module:
    """Build a pointwise (1x1) convolution block: Conv2d -> BatchNorm2d -> SiLU.

    The convolution is created without a bias; the BatchNorm2d that follows has its own learnable shift.

    Args:
        inp: number of input channels.
        oup: number of output channels.

    Returns:
        An ``nn.Sequential`` holding the conv (index 0), the batch norm (index 1) and the activation (index 2).
    """
    pointwise = nn.Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, stride=1, padding=0, bias=False)
    return nn.Sequential(pointwise, nn.BatchNorm2d(oup), nn.SiLU())
def conv_nxn_bn(inp: int, oup: int, kernal_size: int = 3, stride: int = 1) -> Module:
    """Apply NxN Convolution with Batch Norm.

    Builds Conv2d -> BatchNorm2d -> SiLU. The padding is ``kernal_size // 2`` so that odd kernel sizes with
    ``stride=1`` keep the spatial size of the input (a fixed padding of 1 only does this for 3x3 kernels).

    Args:
        inp: number of input channels.
        oup: number of output channels.
        kernal_size: convolution kernel size. The parameter keeps its historical spelling so keyword callers
            keep working.
        stride: convolution stride.

    Returns:
        An ``nn.Sequential`` of the bias-free conv, the batch norm and the SiLU activation.
    """
    padding = kernal_size // 2
    return nn.Sequential(
        nn.Conv2d(inp, oup, kernal_size, stride, padding, bias=False),
        nn.BatchNorm2d(oup),
        nn.SiLU(),
    )
class PreNorm(Module):
    """Apply LayerNorm over the last dimension, then call the wrapped module.

    Args:
        dim: size of the last dimension, normalized by ``nn.LayerNorm``.
        fn: module applied to the normalized input.
    """

    def __init__(self, dim: int, fn: Module) -> None:
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x: Tensor, **kwargs: Any) -> Tensor:
        # Extra keyword arguments are forwarded unchanged to ``fn``; only ``x`` is normalized.
        return self.fn(self.norm(x), **kwargs)
class FeedForward(Module):
    """Two-layer MLP over the last dimension: Linear -> SiLU -> Dropout -> Linear -> Dropout.

    Args:
        dim: input and output feature size.
        hidden_dim: size of the hidden layer.
        dropout: dropout ratio used after the activation and after the output layer, defaults to 0.
    """

    def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.0) -> None:
        super().__init__()
        layers = [
            nn.Linear(dim, hidden_dim),  # expand
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),  # project back to ``dim``
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        out: Tensor = self.net(x)
        return out
class Attention(Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        dim: size of the last (feature) dimension of the input.
        heads: number of attention heads.
        dim_head: size of each head; q, k and v each have ``heads * dim_head`` features.
        dropout: dropout ratio applied after the output projection.

    Shape:
        - Input: :math:`(*, N, dim)` with any number of leading dimensions. Attention is computed among the
          ``N`` tokens independently for every leading index.
        - Output: same shape as the input.
    """

    def __init__(self, dim: int, heads: int = 8, dim_head: int = 64, dropout: float = 0.0) -> None:
        super().__init__()
        inner_dim = dim_head * heads
        # A single head whose size already equals ``dim`` needs no output projection (nor its dropout).
        project_out = not (heads == 1 and dim_head == dim)
        self.heads = heads
        self.scale = dim_head**-0.5
        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout)) if project_out else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        *lead, n, inner_dim = q.shape
        head_dim = inner_dim // self.heads
        # (*, N, heads * head_dim) -> (*, heads, N, head_dim); negative axes support any leading dims.
        q, k, v = (t.reshape(*lead, n, self.heads, head_dim).transpose(-3, -2) for t in (q, k, v))
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = self.attend(dots)
        out = torch.matmul(attn, v)
        # (*, heads, N, head_dim) -> (*, N, heads * head_dim)
        out = out.transpose(-3, -2).reshape(*lead, n, inner_dim)
        return self.to_out(out)
class Transformer(Module):
    """Stack of pre-norm Transformer layers as described in ViT.

    Paper: https://arxiv.org/abs/2010.11929
    Based on: https://github.com/lucidrains/vit-pytorch

    Each layer computes ``x = x + Attention(LayerNorm(x))`` followed by ``x = x + FeedForward(LayerNorm(x))``.

    Args:
        dim: input dimension.
        depth: number of stacked layers.
        heads: number of heads in multi-head attention layer.
        dim_head: head size.
        mlp_dim: hidden dimension of the FeedForward layer.
        dropout: dropout ratio, defaults to 0.
    """

    def __init__(self, dim: int, depth: int, heads: int, dim_head: int, mlp_dim: int, dropout: float = 0.0) -> None:
        super().__init__()
        # One inner ModuleList per layer: index 0 is the attention sub-block, index 1 the feed-forward one.
        self.layers = nn.ModuleList(
            nn.ModuleList(
                [
                    PreNorm(dim, Attention(dim, heads, dim_head, dropout)),
                    PreNorm(dim, FeedForward(dim, mlp_dim, dropout)),
                ]
            )
            for _ in range(depth)
        )

    def forward(self, x: Tensor) -> Tensor:
        for attention_block, feedforward_block in self.layers:
            x = x + attention_block(x)
            x = x + feedforward_block(x)
        return x
class MV2Block(Module):
    """Inverted residual block from MobileNetV2.

    Paper: https://arxiv.org/pdf/1801.04381
    Based on: https://github.com/tonylins/pytorch-mobilenet-v2

    Layout: optional 1x1 expansion -> 3x3 depthwise conv -> 1x1 linear projection. The expansion stage is
    omitted when ``expansion == 1``. The input is added to the output when ``stride == 1`` and ``inp == oup``.

    Args:
        inp: input channel.
        oup: output channel.
        stride: stride of the depthwise convolution, defaults to 1; set to 2 to down-sample.
        expansion: expansion ratio for hidden dimension, defaults to 4.
    """

    def __init__(self, inp: int, oup: int, stride: int = 1, expansion: int = 4) -> None:
        super().__init__()
        self.stride = stride
        hidden_dim = int(inp * expansion)
        self.use_res_connect = stride == 1 and inp == oup

        # pointwise expansion inp -> hidden_dim (skipped when no expansion is requested)
        expand = (
            [nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), nn.BatchNorm2d(hidden_dim), nn.SiLU()]
            if expansion != 1
            else []
        )
        # depthwise 3x3, one group per channel; this conv carries the stride
        depthwise = [
            nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
            nn.BatchNorm2d(hidden_dim),
            nn.SiLU(),
        ]
        # pointwise linear projection hidden_dim -> oup (no activation)
        project = [nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), nn.BatchNorm2d(oup)]
        self.conv = nn.Sequential(*expand, *depthwise, *project)

    def forward(self, x: Tensor) -> Tensor:
        out = self.conv(x)
        return x + out if self.use_res_connect else out
class MobileViTBlock(Module):
    """MobileViT block mentioned in MobileViT.

    Local features are extracted by an NxN conv and projected to ``dim`` channels by a 1x1 conv. The map is
    then split into non-overlapping ``patch_size`` patches. For every pixel position inside a patch, a
    Transformer attends across all patches. The result is folded back and projected to ``channel`` channels.
    It is then concatenated with the block input and fused by a final NxN conv.

    Args:
        dim: input dimension of Transformer.
        depth: depth of Transformer.
        channel: input channel.
        kernel_size: kernel size of the two NxN convolutions.
        patch_size: patch size ``(ph, pw)`` for folding and unfolding.
        mlp_dim: dimension of the FeedForward layer in Transformer.
        dropout: dropout ratio, defaults to 0.
        heads: number of attention heads in the Transformer, defaults to 4.
        dim_head: size of each attention head, defaults to 8.

    Raises:
        ValueError: from ``forward`` if the feature map height/width is not divisible by the patch size.
    """

    def __init__(
        self,
        dim: int,
        depth: int,
        channel: int,
        kernel_size: int,
        patch_size: Tuple[int, int],
        mlp_dim: int,
        dropout: float = 0.0,
        heads: int = 4,
        dim_head: int = 8,
    ) -> None:
        super().__init__()
        self.ph, self.pw = patch_size
        self.conv1 = conv_nxn_bn(channel, channel, kernel_size)
        self.conv2 = conv_1x1_bn(channel, dim)
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
        self.conv3 = conv_1x1_bn(dim, channel)
        self.conv4 = conv_nxn_bn(2 * channel, channel, kernel_size)

    def forward(self, x: Tensor) -> Tensor:
        # The convolutions below return new tensors and never modify ``x`` in place, so the input can be kept
        # for the fusion step without a defensive copy.
        y = x

        # Local representations
        x = self.conv1(x)
        x = self.conv2(x)

        b, d, h, w = x.shape
        if h % self.ph != 0 or w % self.pw != 0:
            raise ValueError(
                f"Feature map size ({h}, {w}) must be divisible by the patch size ({self.ph}, {self.pw})."
            )
        nh, nw = h // self.ph, w // self.pw

        # Global representations: unfold -> transformer -> fold
        # [b, d, h, w] -> [b * d * nh, nw, ph, pw]
        x = x.reshape(b * d * nh, self.ph, nw, self.pw).transpose(1, 2)
        # [b * d * nh, nw, ph, pw] -> [b, (ph pw), (nh nw), d]
        x = x.reshape(b, d, nh * nw, self.ph * self.pw).transpose(1, 3)
        x = self.transformer(x)
        # [b, (ph pw), (nh nw), d] -> [b * d * nh, nw, ph, pw]
        x = x.transpose(1, 3).reshape(b * d * nh, nw, self.ph, self.pw)
        # [b * d * nh, nw, ph, pw] -> [b, d, h, w]
        x = x.transpose(1, 2).reshape(b, d, h, w)

        # Fusion with the block input
        x = self.conv3(x)
        x = torch.cat((x, y), 1)
        return self.conv4(x)
class MobileViT(Module):
    """Module MobileViT. Default arguments are for MobileViT XXS.

    Paper: https://arxiv.org/abs/2110.02178
    Based on: https://github.com/chinhsuanwu/mobilevit-pytorch

    Args:
        mode: 'xxs', 'xs' or 's', defaults to 'xxs'.
        in_channels: the number of channels for the input image.
        patch_size: patch size of the three MobileViT blocks. Those blocks see feature maps at 1/8, 1/16 and
            1/32 of the input resolution (each stride-2 stage rounds up). Each of these maps must be divisible
            by ``patch_size``.
        dropout: dropout ratio in Transformer.

    Raises:
        ValueError: if ``mode`` is not one of 'xxs', 'xs' or 's'.

    Example:
        >>> img = torch.rand(1, 3, 256, 256)
        >>> mvit = MobileViT(mode='xxs')
        >>> mvit(img).shape
        torch.Size([1, 320, 8, 8])
    """

    def __init__(
        self, mode: str = "xxs", in_channels: int = 3, patch_size: Tuple[int, int] = (2, 2), dropout: float = 0.0
    ) -> None:
        super().__init__()
        if mode == "xxs":
            expansion = 2
            dims = [64, 80, 96]
            channels = [16, 16, 24, 24, 48, 48, 64, 64, 80, 80, 320]
        elif mode == "xs":
            expansion = 4
            dims = [96, 120, 144]
            channels = [16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384]
        elif mode == "s":
            expansion = 4
            dims = [144, 192, 240]
            channels = [16, 32, 64, 64, 96, 96, 128, 128, 160, 160, 640]
        else:
            # Without this branch the configuration variables above would be unbound.
            raise ValueError(f"Invalid MobileViT mode {mode!r}; expected one of 'xxs', 'xs' or 's'.")

        kernel_size = 3
        depth = [2, 4, 3]

        # Stem: stride-2 3x3 conv
        self.conv1 = conv_nxn_bn(in_channels, channels[0], stride=2)

        self.mv2 = nn.ModuleList([])
        self.mv2.append(MV2Block(channels[0], channels[1], 1, expansion))
        self.mv2.append(MV2Block(channels[1], channels[2], 2, expansion))
        self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion))
        # Repeat of the previous stage: it consumes the previous block's output (channels[3]).
        self.mv2.append(MV2Block(channels[3], channels[3], 1, expansion))
        self.mv2.append(MV2Block(channels[3], channels[4], 2, expansion))
        self.mv2.append(MV2Block(channels[5], channels[6], 2, expansion))
        self.mv2.append(MV2Block(channels[7], channels[8], 2, expansion))

        self.mvit = nn.ModuleList([])
        self.mvit.append(
            MobileViTBlock(dims[0], depth[0], channels[5], kernel_size, patch_size, int(dims[0] * 2), dropout=dropout)
        )
        self.mvit.append(
            MobileViTBlock(dims[1], depth[1], channels[7], kernel_size, patch_size, int(dims[1] * 4), dropout=dropout)
        )
        self.mvit.append(
            MobileViTBlock(dims[2], depth[2], channels[9], kernel_size, patch_size, int(dims[2] * 4), dropout=dropout)
        )

        # Head: 1x1 conv to the final channel count
        self.conv2 = conv_1x1_bn(channels[-2], channels[-1])

    def forward(self, x: Tensor) -> Tensor:
        x = self.conv1(x)
        x = self.mv2[0](x)
        x = self.mv2[1](x)
        x = self.mv2[2](x)
        x = self.mv2[3](x)  # Repeat
        x = self.mv2[4](x)
        # Each down-sampling MV2 block below is followed by a MobileViT block at the new resolution.
        x = self.mvit[0](x)
        x = self.mv2[5](x)
        x = self.mvit[1](x)
        x = self.mv2[6](x)
        x = self.mvit[2](x)
        x = self.conv2(x)
        return x
|