# dino_head.py
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the license found in the
  5. # LICENSE file in the root directory of this source tree.
  6. import torch
  7. import torch.nn as nn
  8. from torch.nn.init import trunc_normal_
  9. from torch.nn.utils import weight_norm
  10. class DINOHead(nn.Module):
  11. def __init__(
  12. self,
  13. in_dim,
  14. out_dim,
  15. use_bn=False,
  16. nlayers=3,
  17. hidden_dim=2048,
  18. bottleneck_dim=256,
  19. mlp_bias=True,
  20. ):
  21. super().__init__()
  22. nlayers = max(nlayers, 1)
  23. self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
  24. self.apply(self._init_weights)
  25. self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
  26. self.last_layer.weight_g.data.fill_(1)
  27. def _init_weights(self, m):
  28. if isinstance(m, nn.Linear):
  29. trunc_normal_(m.weight, std=0.02)
  30. if isinstance(m, nn.Linear) and m.bias is not None:
  31. nn.init.constant_(m.bias, 0)
  32. def forward(self, x):
  33. x = self.mlp(x)
  34. eps = 1e-6 if x.dtype == torch.float16 else 1e-12
  35. x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
  36. x = self.last_layer(x)
  37. return x
  38. def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
  39. if nlayers == 1:
  40. return nn.Linear(in_dim, bottleneck_dim, bias=bias)
  41. else:
  42. layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
  43. if use_bn:
  44. layers.append(nn.BatchNorm1d(hidden_dim))
  45. layers.append(nn.GELU())
  46. for _ in range(nlayers - 2):
  47. layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
  48. if use_bn:
  49. layers.append(nn.BatchNorm1d(hidden_dim))
  50. layers.append(nn.GELU())
  51. layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
  52. return nn.Sequential(*layers)