| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252 |
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
- # References:
- # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
- # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
- import logging
- from typing import Callable, List, Any, Tuple, Dict
- import torch
- from torch import nn, Tensor
- from .attention import Attention, MemEffAttention
- from .drop_path import DropPath
- from .layer_scale import LayerScale
- from .mlp import Mlp
- logger = logging.getLogger("dinov2")
- try:
- from xformers.ops import fmha
- from xformers.ops import scaled_index_add, index_select_cat
- XFORMERS_AVAILABLE = True
- except ImportError:
- # logger.warning("xFormers not available")
- XFORMERS_AVAILABLE = False
class Block(nn.Module):
    """Transformer block with a pre-normalised attention branch and FFN branch.

    Each branch computes ``LayerScale(sublayer(norm(x)))`` and adds the result
    back onto its input. When the module is in training mode and ``drop_path``
    is positive, stochastic depth is applied to both branches.

    Args:
        dim: Feature size handed to the norm layers, ``attn_class``,
            ``ffn_layer`` and ``LayerScale``.
        num_heads: Forwarded to ``attn_class``.
        mlp_ratio: FFN hidden width as a multiple of ``dim``.
        qkv_bias: Forwarded to ``attn_class``.
        proj_bias: Forwarded to ``attn_class``.
        ffn_bias: Forwarded to ``ffn_layer`` as ``bias``.
        drop: Forwarded to ``attn_class`` as ``proj_drop`` and to ``ffn_layer``
            as ``drop``.
        attn_drop: Forwarded to ``attn_class``.
        init_values: Forwarded to ``LayerScale``; a falsy value replaces
            LayerScale with ``nn.Identity``.
        drop_path: Stochastic depth rate; also stored as ``sample_drop_ratio``.
        act_layer: Forwarded to ``ffn_layer``.
        norm_layer: Factory called with ``dim`` to build both norm layers.
        attn_class: Factory for the attention sublayer.
        ffn_layer: Factory for the feed-forward sublayer.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
    ) -> None:
        super().__init__()

        # Attention branch: norm -> attention -> optional LayerScale -> optional DropPath.
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        # Feed-forward branch: norm -> FFN -> optional LayerScale -> optional DropPath.
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        # Rate used by forward() to choose between the stochastic-depth strategies.
        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor) -> Tensor:
        """Apply the attention and feed-forward residual branches to ``x``.

        Args:
            x: Input tensor. The batch-subsampling path goes through
                ``drop_add_residual_stochastic_depth``, which unpacks
                ``x.shape`` into three dimensions.

        Returns:
            A tensor produced by adding both branch residuals onto ``x``.
        """

        def attn_residual_func(x: Tensor) -> Tensor:
            return self.ls1(self.attn(self.norm1(x)))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.1:
            # the overhead is compensated only for a drop path rate larger than 0.1
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x))
            # The FFN branch uses its own DropPath module (previously reused drop_path1).
            x = x + self.drop_path2(ffn_residual_func(x))
        else:
            x = x + attn_residual_func(x)
            x = x + ffn_residual_func(x)
        return x
def drop_add_residual_stochastic_depth(
    x: Tensor,
    residual_func: Callable[[Tensor], Tensor],
    sample_drop_ratio: float = 0.0,
) -> Tensor:
    """Add a residual computed on a random subset of the batch rows.

    A random subset of rows of ``x`` (at least one) is passed through
    ``residual_func``. The result is scaled by ``batch / subset_size`` and
    added back onto those rows only; all other rows are returned unchanged.

    Args:
        x: Input whose shape unpacks into exactly three dimensions.
        residual_func: Maps the selected rows to their residual.
        sample_drop_ratio: Fraction of batch rows to leave out of the subset.

    Returns:
        A tensor shaped like ``x``.
    """
    batch_size, _, _ = x.shape

    # Pick the surviving rows: first `keep_count` entries of a random permutation.
    keep_count = max(int(batch_size * (1 - sample_drop_ratio)), 1)
    kept_rows = torch.randperm(batch_size, device=x.device)[:keep_count]

    # Run the branch on the surviving rows only.
    branch_out = residual_func(x[kept_rows]).flatten(1)

    # Scatter-add the rescaled residual back onto the rows it came from.
    rescale = batch_size / keep_count
    summed = torch.index_add(
        x.flatten(1),
        0,
        kept_rows,
        branch_out.to(dtype=x.dtype),
        alpha=rescale,
    )
    return summed.view_as(x)
def get_branges_scales(x, sample_drop_ratio=0.0):
    """Draw the random batch rows to keep and the matching residual scale.

    Args:
        x: Tensor whose shape unpacks into exactly three dimensions; only the
            first (batch) size and the device are used.
        sample_drop_ratio: Fraction of batch rows to leave out.

    Returns:
        A pair ``(brange, residual_scale_factor)``: the index tensor of kept
        rows (at least one) and ``batch / number_of_kept_rows``.
    """
    batch_size, _, _ = x.shape
    keep_count = max(int(batch_size * (1 - sample_drop_ratio)), 1)
    shuffled = torch.randperm(batch_size, device=x.device)
    return shuffled[:keep_count], batch_size / keep_count
def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
    """Add ``residual`` onto the rows of ``x`` selected by ``brange``.

    Args:
        x: Base tensor.
        brange: Index tensor of the rows of ``x`` that receive a residual.
        residual: Residual values; cast to ``x.dtype`` before the addition.
        residual_scale_factor: Multiplier passed as ``alpha``.
        scaling_vector: When given, the addition is delegated to xFormers'
            ``scaled_index_add`` with this vector as ``scaling``.

    Returns:
        Without ``scaling_vector``: the result flattened from dim 1 onward
        (callers reshape it). With ``scaling_vector``: whatever
        ``scaled_index_add`` returns.
    """
    if scaling_vector is not None:
        return scaled_index_add(
            x,
            brange,
            residual.to(dtype=x.dtype),
            scaling=scaling_vector,
            alpha=residual_scale_factor,
        )

    flat_residual = residual.flatten(1).to(dtype=x.dtype)
    return torch.index_add(x.flatten(1), 0, brange, flat_residual, alpha=residual_scale_factor)
attn_bias_cache: Dict[Tuple, Any] = {}


def get_attn_bias_and_cat(x_list, branges=None):
    """
    Select rows (optionally), concatenate the tensors into one batch-of-one
    sequence, and return it together with the matching attention bias.

    The bias is a ``fmha.BlockDiagonalMask`` built from one sequence length per
    selected sample. It is memoised in ``attn_bias_cache``, keyed by the
    ``(batch_size, seq_len)`` pair of every input tensor.

    Args:
        x_list: Tensors to pack together.
        branges: Optional per-tensor index tensors; when given, only those rows
            of each tensor are gathered (via ``index_select_cat``).

    Returns:
        ``(attn_bias, cat_tensors)`` where ``cat_tensors`` has a leading
        dimension of 1.
    """
    if branges is None:
        batch_sizes = [t.shape[0] for t in x_list]
    else:
        batch_sizes = [idx.shape[0] for idx in branges]
    cache_key = tuple((bs, t.shape[1]) for bs, t in zip(batch_sizes, x_list))

    if cache_key not in attn_bias_cache:
        # One sequence-length entry per selected sample, in input order.
        seqlens = []
        for bs, t in zip(batch_sizes, x_list):
            seqlens.extend([t.shape[1]] * bs)
        mask = fmha.BlockDiagonalMask.from_seqlens(seqlens)
        mask._batch_sizes = batch_sizes
        attn_bias_cache[cache_key] = mask

    if branges is None:
        # Fold each tensor's batch dimension into its sequence dimension.
        folded = tuple(t.reshape([1, -1, *t.shape[2:]]) for t in x_list)
        cat_tensors = torch.cat(folded, dim=1)
    else:
        flat_inputs = [t.flatten(1) for t in x_list]
        cat_tensors = index_select_cat(flat_inputs, branges).view(1, -1, x_list[0].shape[-1])

    return attn_bias_cache[cache_key], cat_tensors
def drop_add_residual_stochastic_depth_list(
    x_list: List[Tensor],
    residual_func: Callable[[Tensor, Any], Tensor],
    sample_drop_ratio: float = 0.0,
    scaling_vector=None,
) -> List[Tensor]:
    """Apply batch-subset stochastic depth to several tensors in one pass.

    For every tensor a random subset of rows is drawn. The selected rows of all
    tensors are packed into a single sequence, ``residual_func`` is evaluated
    once on the packed input, and the result is split and added back onto the
    corresponding rows of each tensor.

    Args:
        x_list: Input tensors.
        residual_func: Called as ``residual_func(x_cat, attn_bias=attn_bias)``.
        sample_drop_ratio: Fraction of rows to leave out of each tensor's subset.
        scaling_vector: Forwarded to ``add_residual``.

    Returns:
        A list with one tensor per input, each shaped like its input.
    """
    # 1) generate random set of indices for dropping samples in the batch
    branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
    branges = [brange for brange, _ in branges_scales]
    residual_scale_factors = [scale for _, scale in branges_scales]

    # 2) get attention bias and index+concat the tensors
    attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)

    # 3) apply residual_func to get residual, and split the result
    residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore

    # 4) add each scaled residual back and restore the input's shape
    return [
        add_residual(x, brange, residual, scale, scaling_vector).view_as(x)
        for x, brange, residual, scale in zip(x_list, branges, residual_list, residual_scale_factors)
    ]
class NestedTensorBlock(Block):
    """Block that accepts either a single tensor or a list of tensors.

    A list input is packed into one sequence and processed with an attention
    bias obtained from ``get_attn_bias_and_cat``; this path requires xFormers
    and a ``MemEffAttention`` attention module.
    """

    def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
        """
        x_list contains a list of tensors to nest together and run
        """
        assert isinstance(self.attn, MemEffAttention)

        if self.training and self.sample_drop_ratio > 0.0:
            # LayerScale is not applied inside these residual functions; its
            # gamma is handed over as `scaling_vector` instead.

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.attn(self.norm1(x), attn_bias=attn_bias)

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.mlp(self.norm2(x))

            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
            )
            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                # Check ls2 (not ls1) before reading ls2.gamma.
                scaling_vector=self.ls2.gamma if isinstance(self.ls2, LayerScale) else None,
            )
            return x_list
        else:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls2(self.mlp(self.norm2(x)))

            attn_bias, x = get_attn_bias_and_cat(x_list)
            x = x + attn_residual_func(x, attn_bias=attn_bias)
            x = x + ffn_residual_func(x)
            return attn_bias.split(x)

    def forward(self, x_or_x_list):
        """Dispatch on input type.

        Args:
            x_or_x_list: A ``Tensor`` (handled by ``Block.forward``) or a
                ``list`` of tensors (handled by ``forward_nested``).

        Raises:
            AssertionError: If a list is given while xFormers is unavailable,
                or if the input is neither a ``Tensor`` nor a ``list``.
        """
        if isinstance(x_or_x_list, Tensor):
            return super().forward(x_or_x_list)
        elif isinstance(x_or_x_list, list):
            assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
            return self.forward_nested(x_or_x_list)
        else:
            raise AssertionError(
                f"Expected a Tensor or a list of Tensors, got {type(x_or_x_list).__name__}"
            )
|