| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648 |
- """ Image to Patch Embedding using Conv2d
- A convolution based approach to patchifying a 2D image w/ embedding projection.
- Based on code in:
- * https://github.com/google-research/vision_transformer
- * https://github.com/google-research/big_vision/tree/main/big_vision
- Hacked together by / Copyright 2020 Ross Wightman
- """
- import logging
- import math
- from typing import Callable, Dict, List, Optional, Tuple, Union
- import torch
- from torch import nn as nn
- import torch.nn.functional as F
- from .format import Format, nchw_to
- from .helpers import to_2tuple
- from .trace_utils import _assert
- _logger = logging.getLogger(__name__)
- class PatchEmbed(nn.Module):
- """ 2D Image to Patch Embedding
- """
- output_fmt: Format
- dynamic_img_pad: torch.jit.Final[bool]
- def __init__(
- self,
- img_size: Optional[Union[int, Tuple[int, int]]] = 224,
- patch_size: int = 16,
- in_chans: int = 3,
- embed_dim: int = 768,
- norm_layer: Optional[Callable] = None,
- flatten: bool = True,
- output_fmt: Optional[str] = None,
- bias: bool = True,
- strict_img_size: bool = True,
- dynamic_img_pad: bool = False,
- device=None,
- dtype=None,
- ):
- dd = {'device': device, 'dtype': dtype}
- super().__init__()
- self.patch_size = to_2tuple(patch_size)
- self.img_size, self.grid_size, self.num_patches = self._init_img_size(img_size)
- if output_fmt is not None:
- self.flatten = False
- self.output_fmt = Format(output_fmt)
- else:
- # flatten spatial dim and transpose to channels last, kept for bwd compat
- self.flatten = flatten
- self.output_fmt = Format.NCHW
- self.strict_img_size = strict_img_size
- self.dynamic_img_pad = dynamic_img_pad
- self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, **dd)
- self.norm = norm_layer(embed_dim, **dd) if norm_layer else nn.Identity()
- def _init_img_size(self, img_size: Union[int, Tuple[int, int]]):
- assert self.patch_size
- if img_size is None:
- return None, None, None
- img_size = to_2tuple(img_size)
- grid_size = tuple([s // p for s, p in zip(img_size, self.patch_size)])
- num_patches = grid_size[0] * grid_size[1]
- return img_size, grid_size, num_patches
- def set_input_size(
- self,
- img_size: Optional[Union[int, Tuple[int, int]]] = None,
- patch_size: Optional[Union[int, Tuple[int, int]]] = None,
- ):
- new_patch_size = None
- if patch_size is not None:
- new_patch_size = to_2tuple(patch_size)
- if new_patch_size is not None and new_patch_size != self.patch_size:
- with torch.no_grad():
- new_proj = nn.Conv2d(
- self.proj.in_channels,
- self.proj.out_channels,
- kernel_size=new_patch_size,
- stride=new_patch_size,
- bias=self.proj.bias is not None,
- device=self.proj.weight.device,
- dtype=self.proj.weight.dtype,
- )
- new_proj.weight.copy_(resample_patch_embed(self.proj.weight, new_patch_size, verbose=True))
- if self.proj.bias is not None:
- new_proj.bias.copy_(self.proj.bias)
- self.proj = new_proj
- self.patch_size = new_patch_size
- img_size = img_size or self.img_size
- if img_size != self.img_size or new_patch_size is not None:
- self.img_size, self.grid_size, self.num_patches = self._init_img_size(img_size)
- def feat_ratio(self, as_scalar=True) -> Union[Tuple[int, int], int]:
- if as_scalar:
- return max(self.patch_size)
- else:
- return self.patch_size
- def dynamic_feat_size(self, img_size: Tuple[int, int]) -> Tuple[int, int]:
- """ Get grid (feature) size for given image size taking account of dynamic padding.
- NOTE: must be torchscript compatible so using fixed tuple indexing
- """
- if self.dynamic_img_pad:
- return math.ceil(img_size[0] / self.patch_size[0]), math.ceil(img_size[1] / self.patch_size[1])
- else:
- return img_size[0] // self.patch_size[0], img_size[1] // self.patch_size[1]
- def forward(self, x):
- B, C, H, W = x.shape
- if self.img_size is not None:
- if self.strict_img_size:
- _assert(H == self.img_size[0], f"Input height ({H}) doesn't match model ({self.img_size[0]}).")
- _assert(W == self.img_size[1], f"Input width ({W}) doesn't match model ({self.img_size[1]}).")
- elif not self.dynamic_img_pad:
- _assert(
- H % self.patch_size[0] == 0,
- f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})."
- )
- _assert(
- W % self.patch_size[1] == 0,
- f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})."
- )
- if self.dynamic_img_pad:
- pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
- pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
- x = F.pad(x, (0, pad_w, 0, pad_h))
- x = self.proj(x)
- if self.flatten:
- x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
- elif self.output_fmt != Format.NCHW:
- x = nchw_to(x, self.output_fmt)
- x = self.norm(x)
- return x
- class PatchEmbedWithSize(PatchEmbed):
- """ 2D Image to Patch Embedding
- """
- output_fmt: Format
- def __init__(
- self,
- img_size: Optional[Union[int, Tuple[int, int]]] = 224,
- patch_size: int = 16,
- in_chans: int = 3,
- embed_dim: int = 768,
- norm_layer: Optional[Callable] = None,
- flatten: bool = True,
- output_fmt: Optional[str] = None,
- bias: bool = True,
- device=None,
- dtype=None,
- ):
- super().__init__(
- img_size=img_size,
- patch_size=patch_size,
- in_chans=in_chans,
- embed_dim=embed_dim,
- norm_layer=norm_layer,
- flatten=flatten,
- output_fmt=output_fmt,
- bias=bias,
- device=device,
- dtype=dtype,
- )
- def forward(self, x) -> Tuple[torch.Tensor, List[int]]:
- B, C, H, W = x.shape
- if self.img_size is not None:
- _assert(H % self.patch_size[0] == 0, f"Input image height ({H}) must be divisible by patch size ({self.patch_size[0]}).")
- _assert(W % self.patch_size[1] == 0, f"Input image width ({W}) must be divisible by patch size ({self.patch_size[1]}).")
- x = self.proj(x)
- feat_size = x.shape[-2:]
- if self.flatten:
- x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
- elif self.output_fmt != Format.NCHW:
- x = nchw_to(x, self.output_fmt)
- x = self.norm(x)
- return x, feat_size
- # FIXME to remove, keeping for comparison for now
- def resample_patch_embed_old(
- patch_embed,
- new_size: List[int],
- interpolation: str = 'bicubic',
- antialias: bool = True,
- verbose: bool = False,
- ):
- """Resample the weights of the patch embedding kernel to target resolution.
- We resample the patch embedding kernel by approximately inverting the effect
- of patch resizing.
- Code based on:
- https://github.com/google-research/big_vision/blob/b00544b81f8694488d5f36295aeb7972f3755ffe/big_vision/models/proj/flexi/vit.py
- With this resizing, we can for example load a B/8 filter into a B/16 model
- and, on 2x larger input image, the result will match.
- Args:
- patch_embed: original parameter to be resized.
- new_size (tuple(int, int): target shape (height, width)-only.
- interpolation (str): interpolation for resize
- antialias (bool): use anti-aliasing filter in resize
- verbose (bool): log operation
- Returns:
- Resized patch embedding kernel.
- """
- import numpy as np
- try:
- from torch import vmap
- except ImportError:
- from functorch import vmap
- assert len(patch_embed.shape) == 4, "Four dimensions expected"
- assert len(new_size) == 2, "New shape should only be hw"
- old_size = patch_embed.shape[-2:]
- if tuple(old_size) == tuple(new_size):
- return patch_embed
- if verbose:
- _logger.info(f"Resize patch embedding {patch_embed.shape} to {new_size}, w/ {interpolation} interpolation.")
- def resize(x_np, _new_size):
- x_tf = torch.Tensor(x_np)[None, None, ...]
- x_upsampled = F.interpolate(
- x_tf, size=_new_size, mode=interpolation, antialias=antialias)[0, 0, ...].numpy()
- return x_upsampled
- def get_resize_mat(_old_size, _new_size):
- mat = []
- for i in range(np.prod(_old_size)):
- basis_vec = np.zeros(_old_size)
- basis_vec[np.unravel_index(i, _old_size)] = 1.
- mat.append(resize(basis_vec, _new_size).reshape(-1))
- return np.stack(mat).T
- resize_mat = get_resize_mat(old_size, new_size)
- resize_mat_pinv = torch.tensor(np.linalg.pinv(resize_mat.T), device=patch_embed.device)
- def resample_kernel(kernel):
- resampled_kernel = resize_mat_pinv @ kernel.reshape(-1)
- return resampled_kernel.reshape(new_size)
- v_resample_kernel = vmap(vmap(resample_kernel, 0, 0), 1, 1)
- orig_dtype = patch_embed.dtype
- patch_embed = patch_embed.float()
- patch_embed = v_resample_kernel(patch_embed)
- patch_embed = patch_embed.to(orig_dtype)
- return patch_embed
- DTYPE_INTERMEDIATE = torch.float32
- def _compute_resize_matrix(
- old_size: Tuple[int, int],
- new_size: Tuple[int, int],
- interpolation: str,
- antialias: bool,
- device: torch.device,
- dtype: torch.dtype = DTYPE_INTERMEDIATE
- ) -> torch.Tensor:
- """Computes the resize matrix basis vectors and interpolates them to new_size."""
- old_h, old_w = old_size
- new_h, new_w = new_size
- old_total = old_h * old_w
- new_total = new_h * new_w
- eye_matrix = torch.eye(old_total, device=device, dtype=dtype)
- basis_vectors_batch = eye_matrix.reshape(old_total, 1, old_h, old_w)
- resized_basis_vectors_batch = F.interpolate(
- basis_vectors_batch,
- size=new_size,
- mode=interpolation,
- antialias=antialias,
- align_corners=False
- ) # Output shape: (old_total, 1, new_h, new_w)
- resize_matrix = resized_basis_vectors_batch.squeeze(1).permute(1, 2, 0).reshape(new_total, old_total)
- return resize_matrix # Shape: (new_total, old_total)
- def _apply_resampling(
- patch_embed: torch.Tensor,
- pinv_matrix: torch.Tensor,
- new_size_tuple: Tuple[int, int],
- orig_dtype: torch.dtype,
- intermediate_dtype: torch.dtype = DTYPE_INTERMEDIATE
- ) -> torch.Tensor:
- """ Simplified resampling w/o vmap use.
- As proposed by https://github.com/stas-sl
- """
- c_out, c_in, *_ = patch_embed.shape
- patch_embed = patch_embed.reshape(c_out, c_in, -1).to(dtype=intermediate_dtype)
- pinv_matrix = pinv_matrix.to(dtype=intermediate_dtype)
- resampled_patch_embed = patch_embed @ pinv_matrix # (C_out, C_in, P_old * P_old) @ (P_old * P_old, P_new * P_new)
- resampled_patch_embed = resampled_patch_embed.reshape(c_out, c_in, *new_size_tuple).to(dtype=orig_dtype)
- return resampled_patch_embed
- def resample_patch_embed(
- patch_embed: torch.Tensor,
- new_size: List[int],
- interpolation: str = 'bicubic',
- antialias: bool = True,
- verbose: bool = False,
- ):
- """ Standalone function (computes matrix on each call). """
- assert len(patch_embed.shape) == 4, "Input tensor should be 4D (out_ch, in_ch, h, w)"
- assert len(new_size) == 2, "New shape should only be hw (height, width)"
- old_size_tuple: Tuple[int, int] = tuple(patch_embed.shape[-2:])
- new_size_tuple: Tuple[int, int] = tuple(new_size)
- if old_size_tuple == new_size_tuple:
- return patch_embed
- device = patch_embed.device
- orig_dtype = patch_embed.dtype
- resize_mat = _compute_resize_matrix(
- old_size_tuple, new_size_tuple, interpolation, antialias, device, DTYPE_INTERMEDIATE
- )
- pinv_matrix = torch.linalg.pinv(resize_mat) # Calculates the pseudoinverse matrix used for resampling
- resampled_patch_embed = _apply_resampling(
- patch_embed, pinv_matrix, new_size_tuple, orig_dtype, DTYPE_INTERMEDIATE
- )
- return resampled_patch_embed
- class PatchEmbedResamplerFixedOrigSize(nn.Module):
- """
- Resample patch embedding weights from a fixed original size,
- caching the pseudoinverse matrix based on the target size.
- """
- def __init__(
- self,
- orig_size: Tuple[int, int],
- interpolation: str = 'bicubic',
- antialias: bool = True
- ):
- """
- Args:
- orig_size (Tuple[int, int]): The expected original (height, width) of input patch_embed tensors.
- interpolation (str): Interpolation mode.
- antialias (bool): Use anti-aliasing filter in resize.
- """
- super().__init__()
- assert isinstance(orig_size, tuple) and len(orig_size) == 2, \
- "`orig_size` must be a tuple of (height, width)"
- self.orig_size = orig_size # expected original size
- self.interpolation = interpolation
- self.antialias = antialias
- # Cache map key is the target new_size tuple
- self._pinv_cache_map: Dict[Tuple[int, int], str] = {}
- def _get_or_create_pinv_matrix(
- self,
- new_size: Tuple[int, int],
- device: torch.device,
- dtype: torch.dtype = DTYPE_INTERMEDIATE
- ) -> torch.Tensor:
- """Retrieves the cached pinv matrix or computes and caches it for the given new_size."""
- cache_key = new_size
- buffer_name = self._pinv_cache_map.get(cache_key)
- if buffer_name and hasattr(self, buffer_name):
- pinv_matrix = getattr(self, buffer_name)
- if pinv_matrix.device == device and pinv_matrix.dtype == dtype:
- return pinv_matrix
- # Calculate the matrix if not cached or needs update
- resize_mat = _compute_resize_matrix(
- self.orig_size, new_size, self.interpolation, self.antialias, device, dtype
- )
- pinv_matrix = torch.linalg.pinv(resize_mat) # Calculates the pseudoinverse matrix used for resampling
- # Cache using register_buffer
- buffer_name = f"pinv_{new_size[0]}x{new_size[1]}"
- if hasattr(self, buffer_name):
- delattr(self, buffer_name)
- self.register_buffer(buffer_name, pinv_matrix)
- self._pinv_cache_map[cache_key] = buffer_name # Map new_size key to buffer name
- return pinv_matrix
- def forward(self, patch_embed: torch.Tensor, new_size: List[int]) -> torch.Tensor:
- """ Resamples the patch embedding weights to new_size.
- Args:
- patch_embed (torch.Tensor): Original weights (out_ch, in_ch, H_orig, W_orig).
- new_size (List[int]): Target [height, width].
- Returns:
- torch.Tensor: Resampled weights.
- """
- assert len(patch_embed.shape) == 4
- assert len(new_size) == 2
- # Input Validation
- input_size = tuple(patch_embed.shape[-2:])
- assert input_size == self.orig_size, \
- f"Input patch_embed spatial size {input_size} does not match " \
- f"module's expected original size {self.orig_size}"
- new_size_tuple: Tuple[int, int] = tuple(new_size)
- # Check no-op case against self.orig_size
- if self.orig_size == new_size_tuple:
- return patch_embed
- device = patch_embed.device
- orig_dtype = patch_embed.dtype
- # Get or compute the required pseudoinverse matrix
- pinv_matrix = self._get_or_create_pinv_matrix(new_size_tuple, device)
- # Apply the resampling
- resampled_patch_embed = _apply_resampling(patch_embed, pinv_matrix, new_size_tuple, orig_dtype)
- return resampled_patch_embed
- class PatchEmbedInterpolator(nn.Module):
- """Dynamically interpolates patch embedding weights for variable patch sizes.
- This module wraps patch embedding weight resampling functionality to support
- on-the-fly patch size variation during training. It handles both Conv2d and
- Linear patch embeddings.
- Args:
- base_patch_size: The original patch size the model was initialized with
- in_chans: Number of input channels
- embed_dim: Embedding dimension
- interpolation: Interpolation mode for resampling
- antialias: Whether to use antialiasing during interpolation
- """
- def __init__(
- self,
- base_patch_size: Tuple[int, int],
- in_chans: int = 3,
- embed_dim: int = 768,
- interpolation: str = 'bicubic',
- antialias: bool = True,
- ):
- super().__init__()
- self.base_patch_size = base_patch_size
- self.in_chans = in_chans
- self.embed_dim = embed_dim
- self.interpolation = interpolation
- self.antialias = antialias
- def resample_linear_weight(
- self,
- weight: torch.Tensor,
- target_patch_size: Tuple[int, int],
- ) -> torch.Tensor:
- """Resample linear patch embedding weights for a new patch size.
- Args:
- weight: Linear weight tensor of shape [embed_dim, patch_h * patch_w * in_chans]
- target_patch_size: Target (patch_h, patch_w) to resample to
- Returns:
- Resampled weight tensor
- """
- if target_patch_size == self.base_patch_size:
- return weight
- embed_dim = weight.shape[0]
- base_ph, base_pw = self.base_patch_size
- target_ph, target_pw = target_patch_size
- # Reshape linear weight to conv2d format
- # [embed_dim, ph*pw*C] -> [embed_dim, C, ph, pw]
- weight_conv = weight.reshape(embed_dim, base_ph, base_pw, self.in_chans)
- weight_conv = weight_conv.permute(0, 3, 1, 2)
- # Resample using existing function
- weight_conv_resampled = resample_patch_embed(
- weight_conv,
- new_size=[target_ph, target_pw],
- interpolation=self.interpolation,
- antialias=self.antialias,
- verbose=False,
- )
- # Reshape back to linear format
- # [embed_dim, C, ph, pw] -> [embed_dim, ph*pw*C]
- weight_resampled = weight_conv_resampled.permute(0, 2, 3, 1)
- weight_resampled = weight_resampled.reshape(embed_dim, -1)
- return weight_resampled
- def resample_conv_weight(
- self,
- weight: torch.Tensor,
- target_patch_size: Tuple[int, int],
- ) -> torch.Tensor:
- """Resample conv2d patch embedding weights for a new patch size.
- Args:
- weight: Conv2d weight tensor of shape [embed_dim, in_chans, patch_h, patch_w]
- target_patch_size: Target (patch_h, patch_w) to resample to
- Returns:
- Resampled weight tensor
- """
- if target_patch_size == self.base_patch_size:
- return weight
- # Resample using existing function
- weight_resampled = resample_patch_embed(
- weight,
- new_size=list(target_patch_size),
- interpolation=self.interpolation,
- antialias=self.antialias,
- verbose=False,
- )
- return weight_resampled
- def forward(
- self,
- patches: torch.Tensor,
- proj_weight: torch.Tensor,
- proj_bias: Optional[torch.Tensor] = None,
- patch_size: Optional[Tuple[int, int]] = None,
- is_linear: bool = True,
- ) -> torch.Tensor:
- """Apply patch embedding with dynamic weight resampling.
- Args:
- patches: Input patches
- - For linear mode with resampling: [B, N, Ph, Pw, C]
- - For linear mode without resampling: [B, N, Ph*Pw*C]
- - For conv mode: [B, C, H, W]
- proj_weight: Original projection weight
- proj_bias: Optional projection bias
- patch_size: Current patch size (if None, uses base_patch_size)
- is_linear: Whether using linear (True) or conv2d (False) projection
- Returns:
- Embedded patches
- """
- if patch_size is None:
- patch_size = self.base_patch_size
- if is_linear:
- if patch_size != self.base_patch_size:
- # Need to resample - expects unflattened patches
- assert patches.ndim == 5, "Patches must be [B, N, Ph, Pw, C] for resampling"
- B, N, Ph, Pw, C = patches.shape
- # Resample the weight
- weight_resampled = self.resample_linear_weight(proj_weight, patch_size)
- # Flatten patches and apply linear projection
- patches_flat = patches.reshape(B, N, -1)
- output = torch.nn.functional.linear(patches_flat, weight_resampled, proj_bias)
- else:
- # No resampling needed, patches can be pre-flattened
- if patches.ndim == 5:
- B, N, Ph, Pw, C = patches.shape
- patches = patches.reshape(B, N, -1)
- output = torch.nn.functional.linear(patches, proj_weight, proj_bias)
- else:
- # Conv mode
- if patch_size != self.base_patch_size:
- weight_resampled = self.resample_conv_weight(proj_weight, patch_size)
- output = torch.nn.functional.conv2d(
- patches, weight_resampled, proj_bias,
- stride=patch_size, padding=0
- )
- else:
- output = torch.nn.functional.conv2d(
- patches, proj_weight, proj_bias,
- stride=patch_size, padding=0
- )
- return output
- # def divs(n, m=None):
- # m = m or n // 2
- # if m == 1:
- # return [1]
- # if n % m == 0:
- # return [m] + divs(n, m - 1)
- # return divs(n, m - 1)
- #
- #
- # class FlexiPatchEmbed(nn.Module):
- # """ 2D Image to Patch Embedding w/ Flexible Patch sizes (FlexiViT)
- # FIXME WIP
- # """
- # def __init__(
- # self,
- # img_size=240,
- # patch_size=16,
- # in_chans=3,
- # embed_dim=768,
- # base_img_size=240,
- # base_patch_size=32,
- # norm_layer=None,
- # flatten=True,
- # bias=True,
- # ):
- # super().__init__()
- # self.img_size = to_2tuple(img_size)
- # self.patch_size = to_2tuple(patch_size)
- # self.num_patches = 0
- #
- # # full range for 240 = (5, 6, 8, 10, 12, 14, 15, 16, 20, 24, 30, 40, 48)
- # self.seqhw = (6, 8, 10, 12, 14, 15, 16, 20, 24, 30)
- #
- # self.base_img_size = to_2tuple(base_img_size)
- # self.base_patch_size = to_2tuple(base_patch_size)
- # self.base_grid_size = tuple([i // p for i, p in zip(self.base_img_size, self.base_patch_size)])
- # self.base_num_patches = self.base_grid_size[0] * self.base_grid_size[1]
- #
- # self.flatten = flatten
- # self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=self.patch_size, stride=self.patch_size, bias=bias)
- # self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
- #
- # def forward(self, x):
- # B, C, H, W = x.shape
- #
- # if self.patch_size == self.base_patch_size:
- # weight = self.proj.weight
- # else:
- # weight = resample_patch_embed(self.proj.weight, self.patch_size)
- # patch_size = self.patch_size
- # x = F.conv2d(x, weight, bias=self.proj.bias, stride=patch_size)
- # if self.flatten:
- # x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
- # x = self.norm(x)
- # return x
|