# LICENSE HEADER MANAGED BY add-license-header
#
# Copyright 2018 Kornia Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Module that implements Vision Transformer (ViT).

Paper: https://paperswithcode.com/paper/an-image-is-worth-16x16-words-transformers-1

Based on: `https://towardsdatascience.com/implementing-visualttransformer-in-pytorch-184f9f16f632`

Added some tricks from: `https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py`
"""

from __future__ import annotations

from typing import Any, Callable

import torch
from torch import nn

from kornia.core import Module, Tensor, concatenate
from kornia.core.check import KORNIA_CHECK

__all__ = ["VisionTransformer"]


class ResidualAdd(Module):
    """Wrap ``fn`` with a residual (skip) connection, returning ``fn(x, **kwargs) + x``."""

    def __init__(self, fn: Callable[..., Tensor]) -> None:
        super().__init__()
        self.fn = fn

    def forward(self, x: Tensor, **kwargs: Any) -> Tensor:
        # Out-of-place add on purpose: ``fn`` may return a tensor that aliases ``x`` (e.g. an identity
        # mapping) or one that autograd saved for backward; an in-place ``+=`` would corrupt either.
        return self.fn(x, **kwargs) + x


class FeedForward(nn.Sequential):
    """MLP ``Linear -> GELU -> Dropout -> Linear -> Dropout`` used inside each encoder block.

    Args:
        in_features: size of the input features.
        hidden_features: size of the hidden layer.
        out_features: size of the output features.
        dropout_rate: dropout rate applied after the activation and after the output layer.
    """

    def __init__(self, in_features: int, hidden_features: int, out_features: int, dropout_rate: float = 0.0) -> None:
        super().__init__(
            nn.Linear(in_features, hidden_features),
            nn.GELU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_features, out_features),
            nn.Dropout(dropout_rate),  # added one extra as in timm
        )


class MultiHeadAttention(Module):
    """Multi-head scaled dot-product self-attention with a fused query/key/value projection.

    Args:
        emb_size: embedding size of the input tokens; must be divisible by ``num_heads``.
        num_heads: number of attention heads; must be positive.
        att_drop: dropout rate applied to the attention weights.
        proj_drop: dropout rate applied after the output projection.

    Raises:
        ValueError: if ``num_heads`` is not positive or ``emb_size`` is not divisible by ``num_heads``.
    """

    def __init__(self, emb_size: int, num_heads: int, att_drop: float, proj_drop: float) -> None:
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads

        # Validate before computing ``head_size**-0.5``: with emb_size < num_heads the head size would be 0
        # and the power would raise ZeroDivisionError instead of this descriptive error.
        if self.num_heads <= 0 or self.emb_size % self.num_heads:
            raise ValueError(
                "Size of embedding inside the transformer encoder must be divisible by number of heads "
                "for correct multi-head attention. "
                f"Got: {self.emb_size} embedding size and {self.num_heads} numbers of heads"
            )

        head_size = emb_size // num_heads  # from timm
        self.scale = head_size**-0.5  # from timm

        # fuse the queries, keys and values in one matrix
        self.qkv = nn.Linear(emb_size, emb_size * 3)
        self.att_drop = nn.Dropout(att_drop)
        self.projection = nn.Linear(emb_size, emb_size)
        self.projection_drop = nn.Dropout(proj_drop)  # added timm trick

    def forward(self, x: Tensor) -> Tensor:
        B, N, C = x.shape
        # Split the fused projection into queries, keys and values, each (B, num_heads, N, C // num_heads).
        # This line is the same reshape/permute used by timm.
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # attention logits: dot product over the per-head feature axis -> (B, num_heads, N, N)
        att = torch.einsum("bhqd, bhkd -> bhqk", q, k) * self.scale
        att = att.softmax(dim=-1)
        att = self.att_drop(att)

        # weighted sum of the values over the key axis -> (B, num_heads, N, C // num_heads)
        out = torch.einsum("bhal, bhlv -> bhav ", att, v)
        # merge the heads back into one embedding -> (B, N, C)
        out = out.permute(0, 2, 1, 3).contiguous().view(B, N, -1)
        out = self.projection(out)
        out = self.projection_drop(out)
        return out


class TransformerEncoderBlock(nn.Sequential):
    """Pre-norm encoder block: residual self-attention followed by a residual feed-forward MLP.

    Args:
        embed_dim: embedding size of the tokens.
        num_heads: number of attention heads.
        dropout_rate: dropout rate for the projection / MLP / residual branches.
        dropout_attn: dropout rate applied to the attention weights.
    """

    def __init__(self, embed_dim: int, num_heads: int, dropout_rate: float, dropout_attn: float) -> None:
        super().__init__(
            ResidualAdd(
                nn.Sequential(
                    nn.LayerNorm(embed_dim, 1e-6),
                    MultiHeadAttention(embed_dim, num_heads, dropout_attn, dropout_rate),
                    nn.Dropout(dropout_rate),
                )
            ),
            ResidualAdd(
                nn.Sequential(
                    nn.LayerNorm(embed_dim, 1e-6),
                    FeedForward(embed_dim, embed_dim * 4, embed_dim, dropout_rate=dropout_rate),
                    nn.Dropout(dropout_rate),
                )
            ),
        )


class TransformerEncoder(Module):
    """Stack of ``depth`` encoder blocks that also records each block's output.

    After every forward pass, ``self.results`` holds the output of each block, in order.

    Args:
        embed_dim: embedding size of the tokens.
        depth: number of encoder blocks.
        num_heads: number of attention heads per block.
        dropout_rate: dropout rate for the projection / MLP / residual branches.
        dropout_attn: dropout rate applied to the attention weights.
    """

    def __init__(
        self,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        dropout_rate: float = 0.0,
        dropout_attn: float = 0.0,
    ) -> None:
        super().__init__()
        self.blocks = nn.Sequential(
            *(TransformerEncoderBlock(embed_dim, num_heads, dropout_rate, dropout_attn) for _ in range(depth))
        )
        self.results: list[Tensor] = []

    def forward(self, x: Tensor) -> Tensor:
        # reset so that ``results`` only reflects the current call
        self.results = []
        out = x
        for block in self.blocks:
            out = block(out)
            self.results.append(out)
        return out


class PatchEmbedding(Module):
    """Compute the 2d image patch embedding ready to pass to transformer encoder.

    Without a ``backbone``, non-overlapping ``patch_size x patch_size`` patches are projected to
    ``out_channels`` by a strided convolution. With a ``backbone``, its output feature map is used instead,
    and ``out_channels`` plus the number of positions are inferred by probing it with a dummy image.

    Args:
        in_channels: number of channels of the input image.
        out_channels: embedding size (ignored and inferred when a backbone is given).
        patch_size: size of the square patches (only used without a backbone).
        image_size: side length of the square input image.
        backbone: optional module mapping the image to a ``(B, E, H', W')`` feature map.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 768,
        patch_size: int = 16,
        image_size: int = 224,
        backbone: Module | None = None,
    ) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.patch_size = patch_size

        # Explicit ``is None`` test: Module truthiness is unreliable (e.g. an empty nn.Sequential is falsy).
        if backbone is None:
            self.backbone = nn.Conv2d(in_channels, out_channels, kernel_size=patch_size, stride=patch_size)
            feat_size = (image_size // patch_size) ** 2
        else:
            self.backbone = backbone
            out_channels, feat_size = self._compute_feats_dims((in_channels, image_size, image_size))
            self.out_channels = out_channels

        self.cls_token = nn.Parameter(torch.randn(1, 1, out_channels))
        self.positions = nn.Parameter(torch.randn(feat_size + 1, out_channels))

    def _compute_feats_dims(self, image_size: tuple[int, int, int]) -> tuple[int, int]:
        """Probe the backbone with a zero image and return ``(channels, H' * W')`` of its output.

        The probe runs without autograd and with every submodule temporarily in eval mode, so it does not
        update BatchNorm running statistics of the given backbone. Each submodule's original train/eval flag
        is restored afterwards.
        """
        # Create the probe on the backbone's device/dtype so an already-moved backbone can be probed.
        param = next(self.backbone.parameters(), None)
        device = param.device if param is not None else None
        dtype = param.dtype if param is not None else None

        modes = [(module, module.training) for module in self.backbone.modules()]
        self.backbone.eval()
        try:
            with torch.no_grad():
                out = self.backbone(torch.zeros(1, *image_size, device=device, dtype=dtype))
        finally:
            for module, mode in modes:
                module.training = mode
        return out.shape[-3], out.shape[-2] * out.shape[-1]

    def forward(self, x: Tensor) -> Tensor:
        x = self.backbone(x)  # BxExH'xW'
        B = x.shape[0]
        # ``flatten`` (unlike ``view``) also accepts non-contiguous backbone outputs, e.g. channels_last
        x = x.flatten(2).permute(0, 2, 1)  # BxNxE with N = H' * W'
        cls_tokens = self.cls_token.expand(B, -1, -1)  # Bx1xE
        # prepend the cls token to the input
        x = concatenate([cls_tokens, x], dim=1)  # Bx(N+1)xE
        # add position embedding
        x += self.positions
        return x


class VisionTransformer(Module):
    """Vision transformer (ViT) module.

    The module is expected to be used as operator for different vision tasks.

    The method is inspired from existing implementations of the paper :cite:`dosovitskiy2020vit`.

    .. warning::
        This is an experimental API subject to changes in favor of flexibility.

    Args:
        image_size: the size of the input image.
        patch_size: the size of the patch to compute the embedding.
        in_channels: the number of channels for the input.
        embed_dim: the embedding dimension inside the transformer encoder.
        depth: the depth of the transformer.
        num_heads: the number of attention heads.
        dropout_rate: dropout rate.
        dropout_attn: attention dropout rate.
        backbone: an nn.Module to compute the image patches embeddings.

    Example:
        >>> img = torch.rand(1, 3, 224, 224)
        >>> vit = VisionTransformer(image_size=224, patch_size=16)
        >>> vit(img).shape
        torch.Size([1, 197, 768])

    """

    def __init__(
        self,
        image_size: int = 224,
        patch_size: int = 16,
        in_channels: int = 3,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        dropout_rate: float = 0.0,
        dropout_attn: float = 0.0,
        backbone: Module | None = None,
    ) -> None:
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.embed_size = embed_dim

        self.patch_embedding = PatchEmbedding(in_channels, embed_dim, patch_size, image_size, backbone)
        # with a custom backbone the embedding size is inferred and may differ from ``embed_dim``
        hidden_dim = self.patch_embedding.out_channels
        self.encoder = TransformerEncoder(hidden_dim, depth, num_heads, dropout_rate, dropout_attn)
        self.norm = nn.LayerNorm(hidden_dim, 1e-6)

    @property
    def encoder_results(self) -> list[Tensor]:
        """Outputs of each encoder block from the most recent forward pass."""
        return self.encoder.results

    def forward(self, x: Tensor) -> Tensor:
        """Embed the image into tokens and run them through the encoder.

        Args:
            x: image tensor of shape ``(B, in_channels, image_size, image_size)``.

        Returns:
            Normalized token embeddings of shape ``(B, N + 1, E)``, with the class token first.

        Raises:
            TypeError: if ``x`` is not a Tensor.
            ValueError: if ``x`` does not have the expected shape.
        """
        if not isinstance(x, Tensor):
            raise TypeError(f"Input x type is not a Tensor. Got: {type(x)}")

        # Every dimension must match: the position embedding has a fixed number of tokens.
        if x.ndim != 4 or x.shape[-3:] != (self.in_channels, self.image_size, self.image_size):
            raise ValueError(
                f"Input image shape must be Bx{self.in_channels}x{self.image_size}x{self.image_size}. Got: {x.shape}"
            )

        out = self.patch_embedding(x)
        out = self.encoder(out)
        out = self.norm(out)
        return out

    @staticmethod
    def from_config(variant: str, pretrained: bool = False, **kwargs: Any) -> VisionTransformer:
        """Build ViT model based on the given config string.

        The format is ``vit_{size}/{patch_size}``.
        E.g. ``vit_b/16`` means ViT-Base, patch size 16x16. If ``pretrained=True``, AugReg weights are loaded.
        The weights are hosted on HuggingFace's model hub: https://huggingface.co/kornia.

        .. note::
            The available weights are: ``vit_l/16``, ``vit_b/16``, ``vit_s/16``, ``vit_ti/16``,
            ``vit_b/32``, ``vit_s/32``.

        Args:
            variant: ViT model variant e.g. ``vit_b/16``.
            pretrained: whether to load pre-trained AugReg weights.
            kwargs: other keyword arguments that will be passed to :func:`kornia.contrib.vit.VisionTransformer`.

        Returns:
            The respective ViT model

        Raises:
            ValueError: if ``variant`` is not of the form ``vit_{size}/{patch_size}``.
            KeyError: if the model size is unknown.

        Example:
            >>> from kornia.contrib import VisionTransformer
            >>> vit_model = VisionTransformer.from_config("vit_b/16", pretrained=True)

        """
        model_configs = {
            "vit_ti": {"embed_dim": 192, "depth": 12, "num_heads": 3},
            "vit_s": {"embed_dim": 384, "depth": 12, "num_heads": 6},
            "vit_b": {"embed_dim": 768, "depth": 12, "num_heads": 12},
            "vit_l": {"embed_dim": 1024, "depth": 24, "num_heads": 16},
            "vit_h": {"embed_dim": 1280, "depth": 32, "num_heads": 16},
        }

        try:
            model_type, patch_size_str = variant.split("/")
            patch_size = int(patch_size_str)
        except ValueError as err:
            raise ValueError(
                f"Invalid ViT variant '{variant}'. Expected format 'vit_{{size}}/{{patch_size}}', e.g. 'vit_b/16'."
            ) from err

        if model_type not in model_configs:
            raise KeyError(f"Unknown ViT model type '{model_type}'. Available: {sorted(model_configs)}")

        # the variant's architecture takes precedence over any conflicting user kwargs
        kwargs.update(model_configs[model_type], patch_size=patch_size)

        model = VisionTransformer(**kwargs)

        if pretrained:
            url = _get_weight_url(variant)
            state_dict = torch.hub.load_state_dict_from_url(url)
            model.load_state_dict(state_dict)

        return model


_AVAILABLE_WEIGHTS = ["vit_l/16", "vit_b/16", "vit_s/16", "vit_ti/16", "vit_b/32", "vit_s/32"]


def _get_weight_url(variant: str) -> str:
    """Return the URL of the model weights."""
    KORNIA_CHECK(variant in _AVAILABLE_WEIGHTS, f"Variant {variant} does not have pre-trained checkpoint")
    model_type, patch_size = variant.split("/")
    return f"https://huggingface.co/kornia/{model_type}{patch_size}_augreg_i21k_r224/resolve/main/{model_type}-{patch_size}.pth"