| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316 |
- # LICENSE HEADER MANAGED BY add-license-header
- #
- # Copyright 2018 Kornia Team
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- #
- """Module that implements the Vision Transformer (ViT).
- Paper: https://paperswithcode.com/paper/an-image-is-worth-16x16-words-transformers-1
- Based on: `https://towardsdatascience.com/implementing-visualttransformer-in-pytorch-184f9f16f632`
- Added some tricks from: `https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py`
- """
- from __future__ import annotations
- from typing import Any, Callable
- import torch
- from torch import nn
- from kornia.core import Module, Tensor, concatenate
- from kornia.core.check import KORNIA_CHECK
# Public API of this module: only ``VisionTransformer`` is exported by ``from ... import *``.
__all__ = ["VisionTransformer"]
class ResidualAdd(Module):
    """Residual connection: return ``fn(x, **kwargs) + x``.

    Args:
        fn: the callable (usually a module) whose output is added to its own input.
    """

    def __init__(self, fn: Callable[..., Tensor]) -> None:
        super().__init__()
        self.fn = fn

    def forward(self, x: Tensor, **kwargs: Any) -> Tensor:
        # Out-of-place addition on purpose. An in-place ``+=`` on ``fn``'s output would:
        # - mutate the caller's tensor when ``fn`` hands back its input unchanged (e.g. ``nn.Identity``);
        # - break autograd when ``fn`` saved its output for the backward pass (e.g. ``nn.ReLU``).
        return self.fn(x, **kwargs) + x
class FeedForward(nn.Sequential):
    """Two-layer MLP used inside each transformer block.

    Layer order: ``Linear -> GELU -> Dropout -> Linear -> Dropout``. The trailing dropout mirrors timm.

    Args:
        in_features: size of the input features.
        hidden_features: size of the hidden (expanded) layer.
        out_features: size of the output features.
        dropout_rate: probability used by both dropout layers.
    """

    def __init__(self, in_features: int, hidden_features: int, out_features: int, dropout_rate: float = 0.0) -> None:
        expand = nn.Linear(in_features, hidden_features)
        activation = nn.GELU()
        hidden_dropout = nn.Dropout(dropout_rate)
        contract = nn.Linear(hidden_features, out_features)
        output_dropout = nn.Dropout(dropout_rate)  # extra dropout, as in timm
        super().__init__(expand, activation, hidden_dropout, contract, output_dropout)
class MultiHeadAttention(Module):
    """Multi-head scaled dot-product self-attention with a fused query/key/value projection.

    Args:
        emb_size: size of the token embeddings; the last dimension of the input must equal it.
        num_heads: number of attention heads; must be positive and divide ``emb_size``.
        att_drop: dropout probability applied to the attention weights.
        proj_drop: dropout probability applied after the output projection.

    Raises:
        ValueError: if ``num_heads`` is not positive or does not divide ``emb_size``.
    """

    def __init__(self, emb_size: int, num_heads: int, att_drop: float, proj_drop: float) -> None:
        super().__init__()
        # Validate before dividing, so that num_heads == 0 gives a ValueError, not a ZeroDivisionError.
        if num_heads <= 0 or emb_size % num_heads:
            raise ValueError(
                "Size of embedding inside the transformer encoder must be divisible by the number of heads "
                "for correct multi-head attention. "
                f"Got: {emb_size} embedding size and {num_heads} heads"
            )
        self.emb_size = emb_size
        self.num_heads = num_heads
        head_size = emb_size // num_heads  # from timm
        self.scale = head_size**-0.5  # from timm
        # fuse the queries, keys and values in one matrix
        self.qkv = nn.Linear(emb_size, emb_size * 3)
        self.att_drop = nn.Dropout(att_drop)
        self.projection = nn.Linear(emb_size, emb_size)
        self.projection_drop = nn.Dropout(proj_drop)  # added timm trick

    def forward(self, x: Tensor) -> Tensor:
        """Apply self-attention to a token sequence.

        Args:
            x: tokens of shape :math:`(B, N, C)` with ``C == emb_size``.

        Returns:
            Tensor of shape :math:`(B, N, C)`.
        """
        B, N, C = x.shape
        # (B, N, 3*C) -> (3, B, heads, N, head_dim): split q/k/v and the heads (same layout as timm).
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        # Scaled attention scores, contracting over the head dimension: (B, heads, N, N).
        att = torch.einsum("bhqd, bhkd -> bhqk", q, k) * self.scale
        att = att.softmax(dim=-1)
        att = self.att_drop(att)
        # Weighted sum of the values over the key axis: (B, heads, N, head_dim).
        out = torch.einsum("bhal, bhlv -> bhav", att, v)
        # Merge the heads back into the embedding dimension: (B, N, heads * head_dim).
        out = out.permute(0, 2, 1, 3).reshape(B, N, -1)
        out = self.projection(out)
        out = self.projection_drop(out)
        return out
class TransformerEncoderBlock(nn.Sequential):
    """Pre-norm encoder layer: a residual attention branch followed by a residual MLP branch.

    Each branch is ``LayerNorm(eps=1e-6) -> sublayer -> Dropout`` wrapped in ``ResidualAdd``.

    Args:
        embed_dim: token embedding size.
        num_heads: number of attention heads.
        dropout_rate: dropout probability for projections, the MLP and the branch outputs.
        dropout_attn: dropout probability applied to the attention weights.
    """

    def __init__(self, embed_dim: int, num_heads: int, dropout_rate: float, dropout_attn: float) -> None:
        # NOTE(review): each branch ends with two consecutive dropouts.
        # - Attention branch: MultiHeadAttention's own projection_drop, then the nn.Dropout below.
        # - MLP branch: FeedForward's trailing Dropout, then the nn.Dropout below.
        # Kept as-is to preserve training behaviour; confirm whether this is intended.
        attention_branch = nn.Sequential(
            nn.LayerNorm(embed_dim, 1e-6),
            MultiHeadAttention(embed_dim, num_heads, dropout_attn, dropout_rate),
            nn.Dropout(dropout_rate),
        )
        mlp_hidden = embed_dim * 4
        mlp_branch = nn.Sequential(
            nn.LayerNorm(embed_dim, 1e-6),
            FeedForward(embed_dim, mlp_hidden, embed_dim, dropout_rate=dropout_rate),
            nn.Dropout(dropout_rate),
        )
        super().__init__(ResidualAdd(attention_branch), ResidualAdd(mlp_branch))
class TransformerEncoder(Module):
    """Stack of ``depth`` encoder blocks that records every block's output.

    After each forward pass, ``results`` holds one tensor per block, in execution order.

    Args:
        embed_dim: token embedding size.
        depth: number of stacked encoder blocks.
        num_heads: number of attention heads per block.
        dropout_rate: dropout probability used inside the blocks.
        dropout_attn: attention-weight dropout probability used inside the blocks.
    """

    def __init__(
        self,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        dropout_rate: float = 0.0,
        dropout_attn: float = 0.0,
    ) -> None:
        super().__init__()
        layers = [TransformerEncoderBlock(embed_dim, num_heads, dropout_rate, dropout_attn) for _ in range(depth)]
        self.blocks = nn.Sequential(*layers)
        self.results: list[Tensor] = []

    def forward(self, x: Tensor) -> Tensor:
        # Publish a fresh list first, then fill it block by block.
        collected: list[Tensor] = []
        self.results = collected
        for layer in self.blocks.children():
            x = layer(x)
            collected.append(x)
        return x
class PatchEmbedding(Module):
    """Compute the 2d image patch embedding ready to pass to transformer encoder.

    Processing steps:
    - The backbone maps the image to a feature grid. By default this is a strided convolution that
      cuts the image into non-overlapping ``patch_size`` x ``patch_size`` patches.
    - The grid is flattened into a token sequence.
    - A learnable class token is prepended.
    - A learnable positional embedding is added.

    Args:
        in_channels: number of channels of the input image.
        out_channels: embedding size produced by the default convolutional backbone. When ``backbone`` is
            given it is replaced by the channel count the backbone actually outputs.
        patch_size: side of the square patches cut by the default backbone.
        image_size: side of the square input image; fixes the number of positional embeddings.
        backbone: optional module used instead of the default patch convolution. It must map a
            ``(B, in_channels, image_size, image_size)`` image to a 4D feature map ``(B, C, H', W')``.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 768,
        patch_size: int = 16,
        image_size: int = 224,
        backbone: Module | None = None,
    ) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.patch_size = patch_size
        # Compare against None explicitly. ``nn.Sequential``/``nn.ModuleList`` define ``__len__``, so an
        # empty container would be falsy and ``backbone or ...`` would silently swap it for the conv.
        if backbone is not None:
            self.backbone = backbone
            out_channels, feat_size = self._compute_feats_dims((in_channels, image_size, image_size))
            self.out_channels = out_channels
        else:
            self.backbone = nn.Conv2d(in_channels, out_channels, kernel_size=patch_size, stride=patch_size)
            feat_size = (image_size // patch_size) ** 2
        self.cls_token = nn.Parameter(torch.randn(1, 1, out_channels))
        self.positions = nn.Parameter(torch.randn(feat_size + 1, out_channels))

    def _compute_feats_dims(self, image_size: tuple[int, int, int]) -> tuple[int, int]:
        """Run a dummy image through the backbone to infer its output size.

        Args:
            image_size: the ``(C, H, W)`` shape of a single input image.

        Returns:
            The number of output channels and the number of spatial positions ``H' * W'``.
        """
        # Probe in eval mode and without autograd, so the zero image does not update
        # state such as BatchNorm running statistics. Each submodule's mode is restored afterwards.
        modes = [(module, module.training) for module in self.backbone.modules()]
        # Build the probe on the backbone's device/dtype so backbones already moved off the CPU also work.
        ref = next((p for p in self.backbone.parameters() if p.is_floating_point()), None)
        factory = {} if ref is None else {"device": ref.device, "dtype": ref.dtype}
        probe = torch.zeros(1, *image_size, **factory)
        self.backbone.eval()
        try:
            with torch.no_grad():
                out = self.backbone(probe)
        finally:
            for module, mode in modes:
                module.training = mode
        return out.shape[-3], out.shape[-2] * out.shape[-1]

    def forward(self, x: Tensor) -> Tensor:
        """Embed a batch of images as ``(B, H'W' + 1, E)`` tokens, with the class token first."""
        feats = self.backbone(x)
        B, _, _, _ = feats.shape  # the backbone must return a 4D feature map B x E x H' x W'
        # flatten(2) matches view(B, E, -1) but also accepts layouts that cannot be viewed.
        tokens = feats.flatten(2).transpose(1, 2)  # B x (H'W') x E
        cls_tokens = self.cls_token.expand(B, -1, -1)  # B x 1 x E, no copy needed before concatenation
        # prepend the cls token to the input
        tokens = concatenate([cls_tokens, tokens], dim=1)  # B x (H'W' + 1) x E
        # add position embedding
        return tokens + self.positions
class VisionTransformer(Module):
    """Vision transformer (ViT) module.

    The module is expected to be used as operator for different vision tasks.

    The method is inspired from existing implementations of the paper :cite:`dosovitskiy2020vit`.

    .. warning::

        This is an experimental API subject to changes in favor of flexibility.

    Args:
        image_size: the side of the square input image.
        patch_size: the size of the patch to compute the embedding.
        in_channels: the number of channels for the input.
        embed_dim: the embedding dimension inside the transformer encoder. When ``backbone`` is given,
            the encoder width is instead taken from the backbone's output channels.
        depth: the depth of the transformer.
        num_heads: the number of attention heads.
        dropout_rate: dropout rate.
        dropout_attn: attention dropout rate.
        backbone: an nn.Module to compute the image patches embeddings.

    Example:
        >>> img = torch.rand(1, 3, 224, 224)
        >>> vit = VisionTransformer(image_size=224, patch_size=16)
        >>> vit(img).shape
        torch.Size([1, 197, 768])
    """

    def __init__(
        self,
        image_size: int = 224,
        patch_size: int = 16,
        in_channels: int = 3,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        dropout_rate: float = 0.0,
        dropout_attn: float = 0.0,
        backbone: Module | None = None,
    ) -> None:
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.embed_size = embed_dim

        self.patch_embedding = PatchEmbedding(in_channels, embed_dim, patch_size, image_size, backbone)
        # the patch embedding may change the width (custom backbone), so the encoder follows it
        hidden_dim = self.patch_embedding.out_channels
        self.encoder = TransformerEncoder(hidden_dim, depth, num_heads, dropout_rate, dropout_attn)
        self.norm = nn.LayerNorm(hidden_dim, 1e-6)

    @property
    def encoder_results(self) -> list[Tensor]:
        """Per-block outputs recorded by the encoder during the last forward pass."""
        return self.encoder.results

    def forward(self, x: Tensor) -> Tensor:
        """Encode a batch of images.

        Args:
            x: images of shape :math:`(B, C, H, W)` with ``C == in_channels`` and ``H == W == image_size``.

        Returns:
            Layer-normalized tokens of shape :math:`(B, N + 1, D)`. The first token is the class token.

        Raises:
            TypeError: if ``x`` is not a Tensor.
            ValueError: if ``x`` is not 4D or its channel/spatial sizes differ from the configured ones.
        """
        if not isinstance(x, Tensor):
            raise TypeError(f"Input x type is not a Tensor. Got: {type(x)}")
        # Every dimension must match. The previous check only fired when the channels AND both
        # spatial sizes were wrong at the same time.
        expected = (self.in_channels, self.image_size, self.image_size)
        if x.dim() != 4 or tuple(x.shape[1:]) != expected:
            raise ValueError(
                f"Input image shape must be Bx{self.in_channels}x{self.image_size}x{self.image_size}. Got: {x.shape}"
            )

        out = self.patch_embedding(x)
        out = self.encoder(out)
        out = self.norm(out)
        return out

    @staticmethod
    def from_config(variant: str, pretrained: bool = False, **kwargs: Any) -> VisionTransformer:
        """Build ViT model based on the given config string.

        The format is ``vit_{size}/{patch_size}``.
        E.g. ``vit_b/16`` means ViT-Base, patch size 16x16. If ``pretrained=True``, AugReg weights are loaded.
        The weights are hosted on HuggingFace's model hub: https://huggingface.co/kornia.

        .. note::
            The available weights are: ``vit_l/16``, ``vit_b/16``, ``vit_s/16``, ``vit_ti/16``,
            ``vit_b/32``, ``vit_s/32``.

        Args:
            variant: ViT model variant e.g. ``vit_b/16``.
            pretrained: whether to load pre-trained AugReg weights.
            kwargs: other keyword arguments that will be passed to :func:`kornia.contrib.vit.VisionTransformer`.
                ``embed_dim``, ``depth``, ``num_heads`` and ``patch_size`` are overridden by ``variant``.

        Returns:
            The respective ViT model

        Example:
            >>> from kornia.contrib import VisionTransformer
            >>> vit_model = VisionTransformer.from_config("vit_b/16", pretrained=True)
        """
        model_type, patch_size_str = variant.split("/")
        patch_size = int(patch_size_str)

        model_config = {
            "vit_ti": {"embed_dim": 192, "depth": 12, "num_heads": 3},
            "vit_s": {"embed_dim": 384, "depth": 12, "num_heads": 6},
            "vit_b": {"embed_dim": 768, "depth": 12, "num_heads": 12},
            "vit_l": {"embed_dim": 1024, "depth": 24, "num_heads": 16},
            "vit_h": {"embed_dim": 1280, "depth": 32, "num_heads": 16},
        }[model_type]
        kwargs.update(model_config, patch_size=patch_size)

        model = VisionTransformer(**kwargs)

        if pretrained:
            url = _get_weight_url(variant)
            # Map to CPU so the checkpoint loads even if it was saved from CUDA tensors on a
            # GPU-less host. load_state_dict then copies the values into the model's own parameters.
            state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
            model.load_state_dict(state_dict)

        return model
# Variants that have a pre-trained AugReg checkpoint hosted on HuggingFace.
_AVAILABLE_WEIGHTS = [
    "vit_l/16",
    "vit_b/16",
    "vit_s/16",
    "vit_ti/16",
    "vit_b/32",
    "vit_s/32",
]


def _get_weight_url(variant: str) -> str:
    """Build the HuggingFace download URL of the AugReg checkpoint for ``variant``.

    Args:
        variant: model variant such as ``vit_b/16``; it must be listed in ``_AVAILABLE_WEIGHTS``.

    Returns:
        The URL of the ``.pth`` weights file.

    Raises:
        Whatever ``KORNIA_CHECK`` raises when ``variant`` has no checkpoint.
    """
    KORNIA_CHECK(variant in _AVAILABLE_WEIGHTS, f"Variant {variant} does not have pre-trained checkpoint")
    arch, patch = variant.split("/")
    repo = f"https://huggingface.co/kornia/{arch}{patch}_augreg_i21k_r224"
    return f"{repo}/resolve/main/{arch}-{patch}.pth"
|