attention.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
  9. import logging
  10. from torch import Tensor
  11. from torch import nn
  12. import torch
  13. logger = logging.getLogger("dinov2")
  14. try:
  15. from xformers.ops import memory_efficient_attention, unbind, fmha
  16. XFORMERS_AVAILABLE = True
  17. except ImportError:
  18. # logger.warning("xFormers not available")
  19. XFORMERS_AVAILABLE = False
  20. class Attention(nn.Module):
  21. def __init__(
  22. self,
  23. dim: int,
  24. num_heads: int = 8,
  25. qkv_bias: bool = False,
  26. proj_bias: bool = True,
  27. attn_drop: float = 0.0,
  28. proj_drop: float = 0.0,
  29. ) -> None:
  30. super().__init__()
  31. self.num_heads = num_heads
  32. head_dim = dim // num_heads
  33. self.scale = head_dim**-0.5
  34. self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
  35. self.attn_drop = nn.Dropout(attn_drop)
  36. self.proj = nn.Linear(dim, dim, bias=proj_bias)
  37. self.proj_drop = nn.Dropout(proj_drop)
  38. def forward(self, x: Tensor) -> Tensor:
  39. # use new pytorch native attn
  40. qkv = self.qkv(x)
  41. B, N, _ = qkv.shape
  42. C = self.qkv.in_features
  43. qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads)
  44. q, k, v = torch.unbind(qkv, 2)
  45. q, k, v = [t.transpose(1, 2) for t in [q, k, v]]
  46. x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
  47. x = x.transpose(1, 2).reshape([B, N, C])
  48. x = self.proj(x)
  49. x = self.proj_drop(x)
  50. return x
  51. # old code below
  52. B, N, C = x.shape
  53. qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
  54. q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
  55. attn = q @ k.transpose(-2, -1)
  56. attn = attn.softmax(dim=-1)
  57. attn = self.attn_drop(attn)
  58. x = (attn @ v).transpose(1, 2).reshape(B, N, C)
  59. x = self.proj(x)
  60. x = self.proj_drop(x)
  61. return x
  62. class MemEffAttention(Attention):
  63. def forward(self, x: Tensor, attn_bias=None) -> Tensor:
  64. if not XFORMERS_AVAILABLE:
  65. assert attn_bias is None, "xFormers is required for nested tensors usage"
  66. return super().forward(x)
  67. B, N, C = x.shape
  68. qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
  69. q, k, v = unbind(qkv, 2)
  70. x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
  71. x = x.reshape([B, N, C])
  72. x = self.proj(x)
  73. x = self.proj_drop(x)
  74. return x