vit_mobile.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. from typing import Any, Dict, Tuple
  18. import torch
  19. from torch import nn
  20. from kornia.core import Module, Tensor
  21. def conv_1x1_bn(inp: int, oup: int) -> Module:
  22. """Apply 1x1 Convolution with Batch Norm."""
  23. return nn.Sequential(nn.Conv2d(inp, oup, 1, 1, 0, bias=False), nn.BatchNorm2d(oup), nn.SiLU())
  24. def conv_nxn_bn(inp: int, oup: int, kernal_size: int = 3, stride: int = 1) -> Module:
  25. """Apply NxN Convolution with Batch Norm."""
  26. return nn.Sequential(nn.Conv2d(inp, oup, kernal_size, stride, 1, bias=False), nn.BatchNorm2d(oup), nn.SiLU())
  27. class PreNorm(Module):
  28. def __init__(self, dim: int, fn: Module) -> None:
  29. super().__init__()
  30. self.norm = nn.LayerNorm(dim)
  31. self.fn = fn
  32. def forward(self, x: Tensor, **kwargs: Dict[str, Any]) -> Tensor:
  33. return self.fn(self.norm(x), **kwargs)
  34. class FeedForward(Module):
  35. def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.0) -> None:
  36. super().__init__()
  37. self.net = nn.Sequential(
  38. nn.Linear(dim, hidden_dim), nn.SiLU(), nn.Dropout(dropout), nn.Linear(hidden_dim, dim), nn.Dropout(dropout)
  39. )
  40. def forward(self, x: Tensor) -> Tensor:
  41. return self.net(x)
  42. class Attention(Module):
  43. def __init__(self, dim: int, heads: int = 8, dim_head: int = 64, dropout: float = 0.0) -> None:
  44. super().__init__()
  45. inner_dim = dim_head * heads
  46. project_out = not (heads == 1 and dim_head == dim)
  47. self.heads = heads
  48. self.scale = dim_head**-0.5
  49. self.attend = nn.Softmax(dim=-1)
  50. self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
  51. self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout)) if project_out else nn.Identity()
  52. def forward(self, x: Tensor) -> Tensor:
  53. qkv = self.to_qkv(x).chunk(3, dim=-1)
  54. b, p, n, hd = qkv[0].shape
  55. q, k, v = (t.reshape(b, p, n, self.heads, hd // self.heads).transpose(2, 3) for t in qkv)
  56. dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
  57. attn = self.attend(dots)
  58. out = torch.matmul(attn, v)
  59. out = out.transpose(2, 3).reshape(b, p, n, hd)
  60. return self.to_out(out)
  61. class Transformer(Module):
  62. """Transformer block described in ViT.
  63. Paper: https://arxiv.org/abs/2010.11929
  64. Based on: https://github.com/lucidrains/vit-pytorch
  65. Args:
  66. dim: input dimension.
  67. depth: depth for transformer block.
  68. heads: number of heads in multi-head attention layer.
  69. dim_head: head size.
  70. mlp_dim: dimension of the FeedForward layer.
  71. dropout: dropout ratio, defaults to 0.
  72. """
  73. def __init__(self, dim: int, depth: int, heads: int, dim_head: int, mlp_dim: int, dropout: float = 0.0) -> None:
  74. super().__init__()
  75. self.layers = nn.ModuleList([])
  76. for _ in range(depth):
  77. self.layers.append(
  78. nn.ModuleList(
  79. [
  80. PreNorm(dim, Attention(dim, heads, dim_head, dropout)),
  81. PreNorm(dim, FeedForward(dim, mlp_dim, dropout)),
  82. ]
  83. )
  84. )
  85. def forward(self, x: Tensor) -> Tensor:
  86. for attn, ff in self.layers:
  87. x = attn(x) + x
  88. x = ff(x) + x
  89. return x
  90. class MV2Block(Module):
  91. """MV2 block described in MobileNetV2.
  92. Paper: https://arxiv.org/pdf/1801.04381
  93. Based on: https://github.com/tonylins/pytorch-mobilenet-v2
  94. Args:
  95. inp: input channel.
  96. oup: output channel.
  97. stride: stride for convolution, defaults to 1, set to 2 if down-sample.
  98. expansion: expansion ratio for hidden dimension, defaults to 4.
  99. """
  100. def __init__(self, inp: int, oup: int, stride: int = 1, expansion: int = 4) -> None:
  101. super().__init__()
  102. self.stride = stride
  103. hidden_dim = int(inp * expansion)
  104. self.use_res_connect = self.stride == 1 and inp == oup
  105. if expansion == 1:
  106. self.conv = nn.Sequential(
  107. # depthwise
  108. nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
  109. nn.BatchNorm2d(hidden_dim),
  110. nn.SiLU(),
  111. # pointwise
  112. nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
  113. nn.BatchNorm2d(oup),
  114. )
  115. else:
  116. self.conv = nn.Sequential(
  117. # pointwise
  118. nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
  119. nn.BatchNorm2d(hidden_dim),
  120. nn.SiLU(),
  121. # depthwise
  122. nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
  123. nn.BatchNorm2d(hidden_dim),
  124. nn.SiLU(),
  125. # pointwise
  126. nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
  127. nn.BatchNorm2d(oup),
  128. )
  129. def forward(self, x: Tensor) -> Tensor:
  130. if self.use_res_connect:
  131. return x + self.conv(x)
  132. else:
  133. return self.conv(x)
  134. class MobileViTBlock(Module):
  135. """MobileViT block mentioned in MobileViT.
  136. Args:
  137. dim: input dimension of Transformer.
  138. depth: depth of Transformer.
  139. channel: input channel.
  140. kernel_size: kernel size.
  141. patch_size: patch size for folding and unfloding.
  142. mlp_dim: dimension of the FeedForward layer in Transformer.
  143. dropout: dropout ratio, defaults to 0.
  144. """
  145. def __init__(
  146. self,
  147. dim: int,
  148. depth: int,
  149. channel: int,
  150. kernel_size: int,
  151. patch_size: Tuple[int, int],
  152. mlp_dim: int,
  153. dropout: float = 0.0,
  154. ) -> None:
  155. super().__init__()
  156. self.ph, self.pw = patch_size
  157. self.conv1 = conv_nxn_bn(channel, channel, kernel_size)
  158. self.conv2 = conv_1x1_bn(channel, dim)
  159. self.transformer = Transformer(dim, depth, 4, 8, mlp_dim, dropout)
  160. self.conv3 = conv_1x1_bn(dim, channel)
  161. self.conv4 = conv_nxn_bn(2 * channel, channel, kernel_size)
  162. def forward(self, x: Tensor) -> Tensor:
  163. y = x.clone()
  164. # Local representations
  165. x = self.conv1(x)
  166. x = self.conv2(x)
  167. b, d, h, w = x.shape
  168. nh, nw = h // self.ph, w // self.pw
  169. # Global representations
  170. # [b, d, h, w] -> [b * d * nh, nw, ph, pw]
  171. x = x.reshape(b * d * nh, self.ph, nw, self.pw).transpose(1, 2)
  172. # [b * d * nh, nw, ph, pw] -> [b, (ph pw), (nh nw), d]
  173. x = x.reshape(b, d, nh * nw, self.ph * self.pw).transpose(1, 3)
  174. x = self.transformer(x)
  175. # [b, (ph pw), (nh nw), d] -> [b * d * nh, nw, ph, pw]
  176. x = x.transpose(1, 3).reshape(b * d * nh, nw, self.ph, self.pw)
  177. # [b * d * nh, nw, ph, pw] -> [b, d, h, w]
  178. x = x.transpose(1, 2).reshape(b, d, h, w)
  179. # Fusion
  180. x = self.conv3(x)
  181. x = torch.cat((x, y), 1)
  182. x = self.conv4(x)
  183. return x
  184. class MobileViT(Module):
  185. """Module MobileViT. Default arguments is for MobileViT XXS.
  186. Paper: https://arxiv.org/abs/2110.02178
  187. Based on: https://github.com/chinhsuanwu/mobilevit-pytorch
  188. Args:
  189. mode: 'xxs', 'xs' or 's', defaults to 'xxs'.
  190. in_channels: the number of channels for the input image.
  191. patch_size: image_size must be divisible by patch_size.
  192. dropout: dropout ratio in Transformer.
  193. Example:
  194. >>> img = torch.rand(1, 3, 256, 256)
  195. >>> mvit = MobileViT(mode='xxs')
  196. >>> mvit(img).shape
  197. torch.Size([1, 320, 8, 8])
  198. """
  199. def __init__(
  200. self, mode: str = "xxs", in_channels: int = 3, patch_size: Tuple[int, int] = (2, 2), dropout: float = 0.0
  201. ) -> None:
  202. super().__init__()
  203. if mode == "xxs":
  204. expansion = 2
  205. dims = [64, 80, 96]
  206. channels = [16, 16, 24, 24, 48, 48, 64, 64, 80, 80, 320]
  207. elif mode == "xs":
  208. expansion = 4
  209. dims = [96, 120, 144]
  210. channels = [16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384]
  211. elif mode == "s":
  212. expansion = 4
  213. dims = [144, 192, 240]
  214. channels = [16, 32, 64, 64, 96, 96, 128, 128, 160, 160, 640]
  215. kernel_size = 3
  216. depth = [2, 4, 3]
  217. self.conv1 = conv_nxn_bn(in_channels, channels[0], stride=2)
  218. self.mv2 = nn.ModuleList([])
  219. self.mv2.append(MV2Block(channels[0], channels[1], 1, expansion))
  220. self.mv2.append(MV2Block(channels[1], channels[2], 2, expansion))
  221. self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion))
  222. self.mv2.append(MV2Block(channels[2], channels[3], 1, expansion)) # Repeat
  223. self.mv2.append(MV2Block(channels[3], channels[4], 2, expansion))
  224. self.mv2.append(MV2Block(channels[5], channels[6], 2, expansion))
  225. self.mv2.append(MV2Block(channels[7], channels[8], 2, expansion))
  226. self.mvit = nn.ModuleList([])
  227. self.mvit.append(
  228. MobileViTBlock(dims[0], depth[0], channels[5], kernel_size, patch_size, int(dims[0] * 2), dropout=dropout)
  229. )
  230. self.mvit.append(
  231. MobileViTBlock(dims[1], depth[1], channels[7], kernel_size, patch_size, int(dims[1] * 4), dropout=dropout)
  232. )
  233. self.mvit.append(
  234. MobileViTBlock(dims[2], depth[2], channels[9], kernel_size, patch_size, int(dims[2] * 4), dropout=dropout)
  235. )
  236. self.conv2 = conv_1x1_bn(channels[-2], channels[-1])
  237. def forward(self, x: Tensor) -> Tensor:
  238. x = self.conv1(x)
  239. x = self.mv2[0](x)
  240. x = self.mv2[1](x)
  241. x = self.mv2[2](x)
  242. x = self.mv2[3](x) # Repeat
  243. x = self.mv2[4](x)
  244. x = self.mvit[0](x)
  245. x = self.mv2[5](x)
  246. x = self.mvit[1](x)
  247. x = self.mv2[6](x)
  248. x = self.mvit[2](x)
  249. x = self.conv2(x)
  250. return x