encoders.py 2.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. import torch
  2. import torch.nn as nn
  3. import torchvision.models as tvm
  4. from romatch.utils.utils import get_autocast_params
  5. class VGG19(nn.Module):
  6. def __init__(self, pretrained=True, amp = False, amp_dtype = torch.float16) -> None:
  7. super().__init__()
  8. if pretrained:
  9. weights = tvm.vgg.VGG19_BN_Weights.IMAGENET1K_V1
  10. else:
  11. weights = None
  12. self.layers = nn.ModuleList(tvm.vgg19_bn(weights=weights).features[:40])
  13. self.amp = amp
  14. self.amp_dtype = amp_dtype
  15. def forward(self, x, **kwargs):
  16. autocast_device, autocast_enabled, autocast_dtype = get_autocast_params(x.device, self.amp, self.amp_dtype)
  17. with torch.autocast(device_type=autocast_device, enabled=autocast_enabled, dtype = autocast_dtype):
  18. feats = {}
  19. scale = 1
  20. for layer in self.layers:
  21. if isinstance(layer, nn.MaxPool2d):
  22. feats[scale] = x
  23. scale = scale*2
  24. x = layer(x)
  25. return feats
  26. class CNNandDinov2(nn.Module):
  27. def __init__(self, cnn_kwargs = None, amp = False, dinov2_weights = None, amp_dtype = torch.float16):
  28. super().__init__()
  29. if dinov2_weights is None:
  30. dinov2_weights = torch.hub.load_state_dict_from_url("https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth", map_location="cpu")
  31. from .transformer import vit_large
  32. vit_kwargs = dict(img_size= 518,
  33. patch_size= 14,
  34. init_values = 1.0,
  35. ffn_layer = "mlp",
  36. block_chunks = 0,
  37. )
  38. dinov2_vitl14 = vit_large(**vit_kwargs).eval()
  39. dinov2_vitl14.load_state_dict(dinov2_weights)
  40. cnn_kwargs = cnn_kwargs if cnn_kwargs is not None else {}
  41. self.cnn = VGG19(**cnn_kwargs)
  42. self.amp = amp
  43. self.amp_dtype = amp_dtype
  44. if self.amp:
  45. dinov2_vitl14 = dinov2_vitl14.to(self.amp_dtype)
  46. self.dinov2_vitl14 = [dinov2_vitl14] # ugly hack to not show parameters to DDP
  47. def train(self, mode: bool = True):
  48. return self.cnn.train(mode)
  49. def forward(self, x, upsample = False):
  50. B,C,H,W = x.shape
  51. feature_pyramid = self.cnn(x)
  52. if not upsample:
  53. with torch.no_grad():
  54. if self.dinov2_vitl14[0].device != x.device:
  55. self.dinov2_vitl14[0] = self.dinov2_vitl14[0].to(x.device).to(self.amp_dtype)
  56. dinov2_features_16 = self.dinov2_vitl14[0].forward_features(x.to(self.amp_dtype))
  57. features_16 = dinov2_features_16['x_norm_patchtokens'].permute(0,2,1).reshape(B,1024,H//14, W//14)
  58. del dinov2_features_16
  59. feature_pyramid[16] = features_16
  60. return feature_pyramid