dinov2.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the license found in the
  5. # LICENSE file in the root directory of this source tree.
  6. # References:
  7. # https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
  8. # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
  9. from functools import partial
  10. import math
  11. import logging
  12. from typing import Sequence, Tuple, Union, Callable
  13. import torch
  14. import torch.nn as nn
  15. import torch.utils.checkpoint
  16. from torch.nn.init import trunc_normal_
  17. from .layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
  18. def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
  19. if not depth_first and include_root:
  20. fn(module=module, name=name)
  21. for child_name, child_module in module.named_children():
  22. child_name = ".".join((name, child_name)) if name else child_name
  23. named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
  24. if depth_first and include_root:
  25. fn(module=module, name=name)
  26. return module
  27. class BlockChunk(nn.ModuleList):
  28. def forward(self, x):
  29. for b in self:
  30. x = b(x)
  31. return x
  32. class DinoVisionTransformer(nn.Module):
  33. def __init__(
  34. self,
  35. img_size=224,
  36. patch_size=16,
  37. in_chans=3,
  38. embed_dim=768,
  39. depth=12,
  40. num_heads=12,
  41. mlp_ratio=4.0,
  42. qkv_bias=True,
  43. ffn_bias=True,
  44. proj_bias=True,
  45. drop_path_rate=0.0,
  46. drop_path_uniform=False,
  47. init_values=None, # for layerscale: None or 0 => no layerscale
  48. embed_layer=PatchEmbed,
  49. act_layer=nn.GELU,
  50. block_fn=Block,
  51. ffn_layer="mlp",
  52. block_chunks=1,
  53. ):
  54. """
  55. Args:
  56. img_size (int, tuple): input image size
  57. patch_size (int, tuple): patch size
  58. in_chans (int): number of input channels
  59. embed_dim (int): embedding dimension
  60. depth (int): depth of transformer
  61. num_heads (int): number of attention heads
  62. mlp_ratio (int): ratio of mlp hidden dim to embedding dim
  63. qkv_bias (bool): enable bias for qkv if True
  64. proj_bias (bool): enable bias for proj in attn if True
  65. ffn_bias (bool): enable bias for ffn if True
  66. drop_path_rate (float): stochastic depth rate
  67. drop_path_uniform (bool): apply uniform drop rate across blocks
  68. weight_init (str): weight init scheme
  69. init_values (float): layer-scale init values
  70. embed_layer (nn.Module): patch embedding layer
  71. act_layer (nn.Module): MLP activation layer
  72. block_fn (nn.Module): transformer block class
  73. ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
  74. block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
  75. """
  76. super().__init__()
  77. norm_layer = partial(nn.LayerNorm, eps=1e-6)
  78. self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
  79. self.num_tokens = 1
  80. self.n_blocks = depth
  81. self.num_heads = num_heads
  82. self.patch_size = patch_size
  83. self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
  84. num_patches = self.patch_embed.num_patches
  85. self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
  86. self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
  87. if drop_path_uniform is True:
  88. dpr = [drop_path_rate] * depth
  89. else:
  90. dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
  91. if ffn_layer == "mlp":
  92. ffn_layer = Mlp
  93. elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
  94. ffn_layer = SwiGLUFFNFused
  95. elif ffn_layer == "identity":
  96. def f(*args, **kwargs):
  97. return nn.Identity()
  98. ffn_layer = f
  99. else:
  100. raise NotImplementedError
  101. blocks_list = [
  102. block_fn(
  103. dim=embed_dim,
  104. num_heads=num_heads,
  105. mlp_ratio=mlp_ratio,
  106. qkv_bias=qkv_bias,
  107. proj_bias=proj_bias,
  108. ffn_bias=ffn_bias,
  109. drop_path=dpr[i],
  110. norm_layer=norm_layer,
  111. act_layer=act_layer,
  112. ffn_layer=ffn_layer,
  113. init_values=init_values,
  114. )
  115. for i in range(depth)
  116. ]
  117. if block_chunks > 0:
  118. self.chunked_blocks = True
  119. chunked_blocks = []
  120. chunksize = depth // block_chunks
  121. for i in range(0, depth, chunksize):
  122. # this is to keep the block index consistent if we chunk the block list
  123. chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
  124. self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
  125. else:
  126. self.chunked_blocks = False
  127. self.blocks = nn.ModuleList(blocks_list)
  128. self.norm = norm_layer(embed_dim)
  129. self.head = nn.Identity()
  130. self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
  131. self.init_weights()
  132. for param in self.parameters():
  133. param.requires_grad = False
  134. @property
  135. def device(self):
  136. return self.cls_token.device
  137. def init_weights(self):
  138. trunc_normal_(self.pos_embed, std=0.02)
  139. nn.init.normal_(self.cls_token, std=1e-6)
  140. named_apply(init_weights_vit_timm, self)
  141. def interpolate_pos_encoding(self, x, w, h):
  142. previous_dtype = x.dtype
  143. npatch = x.shape[1] - 1
  144. N = self.pos_embed.shape[1] - 1
  145. if npatch == N and w == h:
  146. return self.pos_embed
  147. pos_embed = self.pos_embed.float()
  148. class_pos_embed = pos_embed[:, 0]
  149. patch_pos_embed = pos_embed[:, 1:]
  150. dim = x.shape[-1]
  151. w0 = w // self.patch_size
  152. h0 = h // self.patch_size
  153. # we add a small number to avoid floating point error in the interpolation
  154. # see discussion at https://github.com/facebookresearch/dino/issues/8
  155. w0, h0 = w0 + 0.1, h0 + 0.1
  156. patch_pos_embed = nn.functional.interpolate(
  157. patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
  158. scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
  159. mode="bicubic",
  160. )
  161. assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
  162. patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
  163. return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
  164. def prepare_tokens_with_masks(self, x, masks=None):
  165. B, nc, w, h = x.shape
  166. x = self.patch_embed(x)
  167. if masks is not None:
  168. x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
  169. x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
  170. x = x + self.interpolate_pos_encoding(x, w, h)
  171. return x
  172. def forward_features_list(self, x_list, masks_list):
  173. x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
  174. for blk in self.blocks:
  175. x = blk(x)
  176. all_x = x
  177. output = []
  178. for x, masks in zip(all_x, masks_list):
  179. x_norm = self.norm(x)
  180. output.append(
  181. {
  182. "x_norm_clstoken": x_norm[:, 0],
  183. "x_norm_patchtokens": x_norm[:, 1:],
  184. "x_prenorm": x,
  185. "masks": masks,
  186. }
  187. )
  188. return output
  189. def forward_features(self, x, masks=None):
  190. if isinstance(x, list):
  191. return self.forward_features_list(x, masks)
  192. x = self.prepare_tokens_with_masks(x, masks)
  193. for blk in self.blocks:
  194. x = blk(x)
  195. x_norm = self.norm(x)
  196. return {
  197. "x_norm_clstoken": x_norm[:, 0],
  198. "x_norm_patchtokens": x_norm[:, 1:],
  199. "x_prenorm": x,
  200. "masks": masks,
  201. }
  202. def _get_intermediate_layers_not_chunked(self, x, n=1):
  203. x = self.prepare_tokens_with_masks(x)
  204. # If n is an int, take the n last blocks. If it's a list, take them
  205. output, total_block_len = [], len(self.blocks)
  206. blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
  207. for i, blk in enumerate(self.blocks):
  208. x = blk(x)
  209. if i in blocks_to_take:
  210. output.append(x)
  211. assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
  212. return output
  213. def _get_intermediate_layers_chunked(self, x, n=1):
  214. x = self.prepare_tokens_with_masks(x)
  215. output, i, total_block_len = [], 0, len(self.blocks[-1])
  216. # If n is an int, take the n last blocks. If it's a list, take them
  217. blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
  218. for block_chunk in self.blocks:
  219. for blk in block_chunk[i:]: # Passing the nn.Identity()
  220. x = blk(x)
  221. if i in blocks_to_take:
  222. output.append(x)
  223. i += 1
  224. assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
  225. return output
  226. def get_intermediate_layers(
  227. self,
  228. x: torch.Tensor,
  229. n: Union[int, Sequence] = 1, # Layers or n last layers to take
  230. reshape: bool = False,
  231. return_class_token: bool = False,
  232. norm=True,
  233. ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
  234. if self.chunked_blocks:
  235. outputs = self._get_intermediate_layers_chunked(x, n)
  236. else:
  237. outputs = self._get_intermediate_layers_not_chunked(x, n)
  238. if norm:
  239. outputs = [self.norm(out) for out in outputs]
  240. class_tokens = [out[:, 0] for out in outputs]
  241. outputs = [out[:, 1:] for out in outputs]
  242. if reshape:
  243. B, _, w, h = x.shape
  244. outputs = [
  245. out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
  246. for out in outputs
  247. ]
  248. if return_class_token:
  249. return tuple(zip(outputs, class_tokens))
  250. return tuple(outputs)
  251. def forward(self, *args, is_training=False, **kwargs):
  252. ret = self.forward_features(*args, **kwargs)
  253. if is_training:
  254. return ret
  255. else:
  256. return self.head(ret["x_norm_clstoken"])
  257. def init_weights_vit_timm(module: nn.Module, name: str = ""):
  258. """ViT weight initialization, original timm impl (for reproducibility)"""
  259. if isinstance(module, nn.Linear):
  260. trunc_normal_(module.weight, std=0.02)
  261. if module.bias is not None:
  262. nn.init.zeros_(module.bias)
  263. def vit_small(patch_size=16, **kwargs):
  264. model = DinoVisionTransformer(
  265. patch_size=patch_size,
  266. embed_dim=384,
  267. depth=12,
  268. num_heads=6,
  269. mlp_ratio=4,
  270. block_fn=partial(Block, attn_class=MemEffAttention),
  271. **kwargs,
  272. )
  273. return model
  274. def vit_base(patch_size=16, **kwargs):
  275. model = DinoVisionTransformer(
  276. patch_size=patch_size,
  277. embed_dim=768,
  278. depth=12,
  279. num_heads=12,
  280. mlp_ratio=4,
  281. block_fn=partial(Block, attn_class=MemEffAttention),
  282. **kwargs,
  283. )
  284. return model
  285. def vit_large(patch_size=16, **kwargs):
  286. model = DinoVisionTransformer(
  287. patch_size=patch_size,
  288. embed_dim=1024,
  289. depth=24,
  290. num_heads=16,
  291. mlp_ratio=4,
  292. block_fn=partial(Block, attn_class=MemEffAttention),
  293. **kwargs,
  294. )
  295. return model
  296. def vit_giant2(patch_size=16, **kwargs):
  297. """
  298. Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
  299. """
  300. model = DinoVisionTransformer(
  301. patch_size=patch_size,
  302. embed_dim=1536,
  303. depth=40,
  304. num_heads=24,
  305. mlp_ratio=4,
  306. block_fn=partial(Block, attn_class=MemEffAttention),
  307. **kwargs,
  308. )
  309. return model