block.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the license found in the
  5. # LICENSE file in the root directory of this source tree.
  6. # References:
  7. # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
  8. # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
  9. import logging
  10. from typing import Callable, List, Any, Tuple, Dict
  11. import torch
  12. from torch import nn, Tensor
  13. from .attention import Attention, MemEffAttention
  14. from .drop_path import DropPath
  15. from .layer_scale import LayerScale
  16. from .mlp import Mlp
  17. logger = logging.getLogger("dinov2")
  18. try:
  19. from xformers.ops import fmha
  20. from xformers.ops import scaled_index_add, index_select_cat
  21. XFORMERS_AVAILABLE = True
  22. except ImportError:
  23. # logger.warning("xFormers not available")
  24. XFORMERS_AVAILABLE = False
  25. class Block(nn.Module):
  26. def __init__(
  27. self,
  28. dim: int,
  29. num_heads: int,
  30. mlp_ratio: float = 4.0,
  31. qkv_bias: bool = False,
  32. proj_bias: bool = True,
  33. ffn_bias: bool = True,
  34. drop: float = 0.0,
  35. attn_drop: float = 0.0,
  36. init_values=None,
  37. drop_path: float = 0.0,
  38. act_layer: Callable[..., nn.Module] = nn.GELU,
  39. norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
  40. attn_class: Callable[..., nn.Module] = Attention,
  41. ffn_layer: Callable[..., nn.Module] = Mlp,
  42. ) -> None:
  43. super().__init__()
  44. # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
  45. self.norm1 = norm_layer(dim)
  46. self.attn = attn_class(
  47. dim,
  48. num_heads=num_heads,
  49. qkv_bias=qkv_bias,
  50. proj_bias=proj_bias,
  51. attn_drop=attn_drop,
  52. proj_drop=drop,
  53. )
  54. self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
  55. self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
  56. self.norm2 = norm_layer(dim)
  57. mlp_hidden_dim = int(dim * mlp_ratio)
  58. self.mlp = ffn_layer(
  59. in_features=dim,
  60. hidden_features=mlp_hidden_dim,
  61. act_layer=act_layer,
  62. drop=drop,
  63. bias=ffn_bias,
  64. )
  65. self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
  66. self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
  67. self.sample_drop_ratio = drop_path
  68. def forward(self, x: Tensor) -> Tensor:
  69. def attn_residual_func(x: Tensor) -> Tensor:
  70. return self.ls1(self.attn(self.norm1(x)))
  71. def ffn_residual_func(x: Tensor) -> Tensor:
  72. return self.ls2(self.mlp(self.norm2(x)))
  73. if self.training and self.sample_drop_ratio > 0.1:
  74. # the overhead is compensated only for a drop path rate larger than 0.1
  75. x = drop_add_residual_stochastic_depth(
  76. x,
  77. residual_func=attn_residual_func,
  78. sample_drop_ratio=self.sample_drop_ratio,
  79. )
  80. x = drop_add_residual_stochastic_depth(
  81. x,
  82. residual_func=ffn_residual_func,
  83. sample_drop_ratio=self.sample_drop_ratio,
  84. )
  85. elif self.training and self.sample_drop_ratio > 0.0:
  86. x = x + self.drop_path1(attn_residual_func(x))
  87. x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
  88. else:
  89. x = x + attn_residual_func(x)
  90. x = x + ffn_residual_func(x)
  91. return x
  92. def drop_add_residual_stochastic_depth(
  93. x: Tensor,
  94. residual_func: Callable[[Tensor], Tensor],
  95. sample_drop_ratio: float = 0.0,
  96. ) -> Tensor:
  97. # 1) extract subset using permutation
  98. b, n, d = x.shape
  99. sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
  100. brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
  101. x_subset = x[brange]
  102. # 2) apply residual_func to get residual
  103. residual = residual_func(x_subset)
  104. x_flat = x.flatten(1)
  105. residual = residual.flatten(1)
  106. residual_scale_factor = b / sample_subset_size
  107. # 3) add the residual
  108. x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
  109. return x_plus_residual.view_as(x)
  110. def get_branges_scales(x, sample_drop_ratio=0.0):
  111. b, n, d = x.shape
  112. sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
  113. brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
  114. residual_scale_factor = b / sample_subset_size
  115. return brange, residual_scale_factor
  116. def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
  117. if scaling_vector is None:
  118. x_flat = x.flatten(1)
  119. residual = residual.flatten(1)
  120. x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
  121. else:
  122. x_plus_residual = scaled_index_add(
  123. x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
  124. )
  125. return x_plus_residual
  126. attn_bias_cache: Dict[Tuple, Any] = {}
  127. def get_attn_bias_and_cat(x_list, branges=None):
  128. """
  129. this will perform the index select, cat the tensors, and provide the attn_bias from cache
  130. """
  131. batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
  132. all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
  133. if all_shapes not in attn_bias_cache.keys():
  134. seqlens = []
  135. for b, x in zip(batch_sizes, x_list):
  136. for _ in range(b):
  137. seqlens.append(x.shape[1])
  138. attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
  139. attn_bias._batch_sizes = batch_sizes
  140. attn_bias_cache[all_shapes] = attn_bias
  141. if branges is not None:
  142. cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
  143. else:
  144. tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
  145. cat_tensors = torch.cat(tensors_bs1, dim=1)
  146. return attn_bias_cache[all_shapes], cat_tensors
  147. def drop_add_residual_stochastic_depth_list(
  148. x_list: List[Tensor],
  149. residual_func: Callable[[Tensor, Any], Tensor],
  150. sample_drop_ratio: float = 0.0,
  151. scaling_vector=None,
  152. ) -> Tensor:
  153. # 1) generate random set of indices for dropping samples in the batch
  154. branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
  155. branges = [s[0] for s in branges_scales]
  156. residual_scale_factors = [s[1] for s in branges_scales]
  157. # 2) get attention bias and index+concat the tensors
  158. attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
  159. # 3) apply residual_func to get residual, and split the result
  160. residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
  161. outputs = []
  162. for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
  163. outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
  164. return outputs
  165. class NestedTensorBlock(Block):
  166. def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
  167. """
  168. x_list contains a list of tensors to nest together and run
  169. """
  170. assert isinstance(self.attn, MemEffAttention)
  171. if self.training and self.sample_drop_ratio > 0.0:
  172. def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
  173. return self.attn(self.norm1(x), attn_bias=attn_bias)
  174. def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
  175. return self.mlp(self.norm2(x))
  176. x_list = drop_add_residual_stochastic_depth_list(
  177. x_list,
  178. residual_func=attn_residual_func,
  179. sample_drop_ratio=self.sample_drop_ratio,
  180. scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
  181. )
  182. x_list = drop_add_residual_stochastic_depth_list(
  183. x_list,
  184. residual_func=ffn_residual_func,
  185. sample_drop_ratio=self.sample_drop_ratio,
  186. scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
  187. )
  188. return x_list
  189. else:
  190. def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
  191. return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
  192. def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
  193. return self.ls2(self.mlp(self.norm2(x)))
  194. attn_bias, x = get_attn_bias_and_cat(x_list)
  195. x = x + attn_residual_func(x, attn_bias=attn_bias)
  196. x = x + ffn_residual_func(x)
  197. return attn_bias.split(x)
  198. def forward(self, x_or_x_list):
  199. if isinstance(x_or_x_list, Tensor):
  200. return super().forward(x_or_x_list)
  201. elif isinstance(x_or_x_list, list):
  202. assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
  203. return self.forward_nested(x_or_x_list)
  204. else:
  205. raise AssertionError