vit.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316
  1. # LICENSE HEADER MANAGED BY add-license-header
  2. #
  3. # Copyright 2018 Kornia Team
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. #
  17. """Module that implement Vision Transformer (ViT).
  18. Paper: https://paperswithcode.com/paper/an-image-is-worth-16x16-words-transformers-1
  19. Based on: `https://towardsdatascience.com/implementing-visualttransformer-in-pytorch-184f9f16f632`
  20. Added some tricks from: `https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py`
  21. """
  22. from __future__ import annotations
  23. from typing import Any, Callable
  24. import torch
  25. from torch import nn
  26. from kornia.core import Module, Tensor, concatenate
  27. from kornia.core.check import KORNIA_CHECK
  28. __all__ = ["VisionTransformer"]
  29. class ResidualAdd(Module):
  30. def __init__(self, fn: Callable[..., Tensor]) -> None:
  31. super().__init__()
  32. self.fn = fn
  33. def forward(self, x: Tensor, **kwargs: Any) -> Tensor:
  34. res = x
  35. x = self.fn(x, **kwargs)
  36. x += res
  37. return x
  38. class FeedForward(nn.Sequential):
  39. def __init__(self, in_features: int, hidden_features: int, out_features: int, dropout_rate: float = 0.0) -> None:
  40. super().__init__(
  41. nn.Linear(in_features, hidden_features),
  42. nn.GELU(),
  43. nn.Dropout(dropout_rate),
  44. nn.Linear(hidden_features, out_features),
  45. nn.Dropout(dropout_rate), # added one extra as in timm
  46. )
  47. class MultiHeadAttention(Module):
  48. def __init__(self, emb_size: int, num_heads: int, att_drop: float, proj_drop: float) -> None:
  49. super().__init__()
  50. self.emb_size = emb_size
  51. self.num_heads = num_heads
  52. head_size = emb_size // num_heads # from timm
  53. self.scale = head_size**-0.5 # from timm
  54. if self.emb_size % self.num_heads:
  55. raise ValueError(
  56. f"Size of embedding inside the transformer decoder must be visible by number of heads"
  57. f"for correct multi-head attention "
  58. f"Got: {self.emb_size} embedding size and {self.num_heads} numbers of heads"
  59. )
  60. # fuse the queries, keys and values in one matrix
  61. self.qkv = nn.Linear(emb_size, emb_size * 3)
  62. self.att_drop = nn.Dropout(att_drop)
  63. self.projection = nn.Linear(emb_size, emb_size)
  64. self.projection_drop = nn.Dropout(proj_drop) # added timm trick
  65. def forward(self, x: Tensor) -> Tensor:
  66. B, N, C = x.shape
  67. # split keys, queries and values in num_heads
  68. # NOTE: the line below differs from timm
  69. # timm: qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
  70. qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
  71. q, k, v = qkv[0], qkv[1], qkv[2]
  72. # sum up over the last axis
  73. att = torch.einsum("bhqd, bhkd -> bhqk", q, k) * self.scale
  74. att = att.softmax(dim=-1)
  75. att = self.att_drop(att)
  76. # sum up over the third axis
  77. out = torch.einsum("bhal, bhlv -> bhav ", att, v)
  78. out = out.permute(0, 2, 1, 3).contiguous().view(B, N, -1)
  79. out = self.projection(out)
  80. out = self.projection_drop(out)
  81. return out
  82. class TransformerEncoderBlock(nn.Sequential):
  83. def __init__(self, embed_dim: int, num_heads: int, dropout_rate: float, dropout_attn: float) -> None:
  84. super().__init__(
  85. ResidualAdd(
  86. nn.Sequential(
  87. nn.LayerNorm(embed_dim, 1e-6),
  88. MultiHeadAttention(embed_dim, num_heads, dropout_attn, dropout_rate),
  89. nn.Dropout(dropout_rate),
  90. )
  91. ),
  92. ResidualAdd(
  93. nn.Sequential(
  94. nn.LayerNorm(embed_dim, 1e-6),
  95. FeedForward(embed_dim, embed_dim * 4, embed_dim, dropout_rate=dropout_rate),
  96. nn.Dropout(dropout_rate),
  97. )
  98. ),
  99. )
  100. class TransformerEncoder(Module):
  101. def __init__(
  102. self,
  103. embed_dim: int = 768,
  104. depth: int = 12,
  105. num_heads: int = 12,
  106. dropout_rate: float = 0.0,
  107. dropout_attn: float = 0.0,
  108. ) -> None:
  109. super().__init__()
  110. self.blocks = nn.Sequential(
  111. *(TransformerEncoderBlock(embed_dim, num_heads, dropout_rate, dropout_attn) for _ in range(depth))
  112. )
  113. self.results: list[Tensor] = []
  114. def forward(self, x: Tensor) -> Tensor:
  115. self.results = []
  116. out = x
  117. for m in self.blocks.children():
  118. out = m(out)
  119. self.results.append(out)
  120. return out
  121. class PatchEmbedding(Module):
  122. """Compute the 2d image patch embedding ready to pass to transformer encoder."""
  123. def __init__(
  124. self,
  125. in_channels: int = 3,
  126. out_channels: int = 768,
  127. patch_size: int = 16,
  128. image_size: int = 224,
  129. backbone: Module | None = None,
  130. ) -> None:
  131. super().__init__()
  132. self.in_channels = in_channels
  133. self.out_channels = out_channels
  134. self.patch_size = patch_size
  135. # logic needed in case a backbone is passed
  136. self.backbone = backbone or nn.Conv2d(in_channels, out_channels, kernel_size=patch_size, stride=patch_size)
  137. if backbone is not None:
  138. out_channels, feat_size = self._compute_feats_dims((in_channels, image_size, image_size))
  139. self.out_channels = out_channels
  140. else:
  141. feat_size = (image_size // patch_size) ** 2
  142. self.cls_token = nn.Parameter(torch.randn(1, 1, out_channels))
  143. self.positions = nn.Parameter(torch.randn(feat_size + 1, out_channels))
  144. def _compute_feats_dims(self, image_size: tuple[int, int, int]) -> tuple[int, int]:
  145. out = self.backbone(torch.zeros(1, *image_size)).detach()
  146. return out.shape[-3], out.shape[-2] * out.shape[-1]
  147. def forward(self, x: Tensor) -> Tensor:
  148. x = self.backbone(x)
  149. B, N, _, _ = x.shape
  150. x = x.view(B, N, -1).permute(0, 2, 1) # BxNxE
  151. cls_tokens = self.cls_token.repeat(B, 1, 1) # Bx1xE
  152. # prepend the cls token to the input
  153. x = concatenate([cls_tokens, x], dim=1) # Bx(N+1)xE
  154. # add position embedding
  155. x += self.positions
  156. return x
  157. class VisionTransformer(Module):
  158. """Vision transformer (ViT) module.
  159. The module is expected to be used as operator for different vision tasks.
  160. The method is inspired from existing implementations of the paper :cite:`dosovitskiy2020vit`.
  161. .. warning::
  162. This is an experimental API subject to changes in favor of flexibility.
  163. Args:
  164. image_size: the size of the input image.
  165. patch_size: the size of the patch to compute the embedding.
  166. in_channels: the number of channels for the input.
  167. embed_dim: the embedding dimension inside the transformer encoder.
  168. depth: the depth of the transformer.
  169. num_heads: the number of attention heads.
  170. dropout_rate: dropout rate.
  171. dropout_attn: attention dropout rate.
  172. backbone: an nn.Module to compute the image patches embeddings.
  173. Example:
  174. >>> img = torch.rand(1, 3, 224, 224)
  175. >>> vit = VisionTransformer(image_size=224, patch_size=16)
  176. >>> vit(img).shape
  177. torch.Size([1, 197, 768])
  178. """
  179. def __init__(
  180. self,
  181. image_size: int = 224,
  182. patch_size: int = 16,
  183. in_channels: int = 3,
  184. embed_dim: int = 768,
  185. depth: int = 12,
  186. num_heads: int = 12,
  187. dropout_rate: float = 0.0,
  188. dropout_attn: float = 0.0,
  189. backbone: Module | None = None,
  190. ) -> None:
  191. super().__init__()
  192. self.image_size = image_size
  193. self.patch_size = patch_size
  194. self.in_channels = in_channels
  195. self.embed_size = embed_dim
  196. self.patch_embedding = PatchEmbedding(in_channels, embed_dim, patch_size, image_size, backbone)
  197. hidden_dim = self.patch_embedding.out_channels
  198. self.encoder = TransformerEncoder(hidden_dim, depth, num_heads, dropout_rate, dropout_attn)
  199. self.norm = nn.LayerNorm(hidden_dim, 1e-6)
  200. @property
  201. def encoder_results(self) -> list[Tensor]:
  202. return self.encoder.results
  203. def forward(self, x: Tensor) -> Tensor:
  204. if not isinstance(x, Tensor):
  205. raise TypeError(f"Input x type is not a Tensor. Got: {type(x)}")
  206. if self.image_size not in (*x.shape[-2:],) and x.shape[-3] != self.in_channels:
  207. raise ValueError(
  208. f"Input image shape must be Bx{self.in_channels}x{self.image_size}x{self.image_size}. Got: {x.shape}"
  209. )
  210. out = self.patch_embedding(x)
  211. out = self.encoder(out)
  212. out = self.norm(out)
  213. return out
  214. @staticmethod
  215. def from_config(variant: str, pretrained: bool = False, **kwargs: Any) -> VisionTransformer:
  216. """Build ViT model based on the given config string.
  217. The format is ``vit_{size}/{patch_size}``.
  218. E.g. ``vit_b/16`` means ViT-Base, patch size 16x16. If ``pretrained=True``, AugReg weights are loaded.
  219. The weights are hosted on HuggingFace's model hub: https://huggingface.co/kornia.
  220. .. note::
  221. The available weights are: ``vit_l/16``, ``vit_b/16``, ``vit_s/16``, ``vit_ti/16``,
  222. ``vit_b/32``, ``vit_s/32``.
  223. Args:
  224. variant: ViT model variant e.g. ``vit_b/16``.
  225. pretrained: whether to load pre-trained AugReg weights.
  226. kwargs: other keyword arguments that will be passed to :func:`kornia.contrib.vit.VisionTransformer`.
  227. Returns:
  228. The respective ViT model
  229. Example:
  230. >>> from kornia.contrib import VisionTransformer
  231. >>> vit_model = VisionTransformer.from_config("vit_b/16", pretrained=True)
  232. """
  233. model_type, patch_size_str = variant.split("/")
  234. patch_size = int(patch_size_str)
  235. model_config = {
  236. "vit_ti": {"embed_dim": 192, "depth": 12, "num_heads": 3},
  237. "vit_s": {"embed_dim": 384, "depth": 12, "num_heads": 6},
  238. "vit_b": {"embed_dim": 768, "depth": 12, "num_heads": 12},
  239. "vit_l": {"embed_dim": 1024, "depth": 24, "num_heads": 16},
  240. "vit_h": {"embed_dim": 1280, "depth": 32, "num_heads": 16},
  241. }[model_type]
  242. kwargs.update(model_config, patch_size=patch_size)
  243. model = VisionTransformer(**kwargs)
  244. if pretrained:
  245. url = _get_weight_url(variant)
  246. state_dict = torch.hub.load_state_dict_from_url(url)
  247. model.load_state_dict(state_dict)
  248. return model
  249. _AVAILABLE_WEIGHTS = ["vit_l/16", "vit_b/16", "vit_s/16", "vit_ti/16", "vit_b/32", "vit_s/32"]
  250. def _get_weight_url(variant: str) -> str:
  251. """Return the URL of the model weights."""
  252. KORNIA_CHECK(variant in _AVAILABLE_WEIGHTS, f"Variant {variant} does not have pre-trained checkpoint")
  253. model_type, patch_size = variant.split("/")
  254. return f"https://huggingface.co/kornia/{model_type}{patch_size}_augreg_i21k_r224/resolve/main/{model_type}-{patch_size}.pth"