| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859 |
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
- import torch
- import torch.nn as nn
- from torch.nn.init import trunc_normal_
- from torch.nn.utils import weight_norm
- class DINOHead(nn.Module):
- def __init__(
- self,
- in_dim,
- out_dim,
- use_bn=False,
- nlayers=3,
- hidden_dim=2048,
- bottleneck_dim=256,
- mlp_bias=True,
- ):
- super().__init__()
- nlayers = max(nlayers, 1)
- self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
- self.apply(self._init_weights)
- self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
- self.last_layer.weight_g.data.fill_(1)
- def _init_weights(self, m):
- if isinstance(m, nn.Linear):
- trunc_normal_(m.weight, std=0.02)
- if isinstance(m, nn.Linear) and m.bias is not None:
- nn.init.constant_(m.bias, 0)
- def forward(self, x):
- x = self.mlp(x)
- eps = 1e-6 if x.dtype == torch.float16 else 1e-12
- x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
- x = self.last_layer(x)
- return x
- def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
- if nlayers == 1:
- return nn.Linear(in_dim, bottleneck_dim, bias=bias)
- else:
- layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
- if use_bn:
- layers.append(nn.BatchNorm1d(hidden_dim))
- layers.append(nn.GELU())
- for _ in range(nlayers - 2):
- layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
- if use_bn:
- layers.append(nn.BatchNorm1d(hidden_dim))
- layers.append(nn.GELU())
- layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
- return nn.Sequential(*layers)
|