Bladeren bron

remove hardcoded .half casts and bf16

Johan Edstedt 2 jaren geleden
bovenliggende
commit
e5751d80cb

+ 1 - 1
requirements.txt

@@ -10,4 +10,4 @@ matplotlib
 h5py
 wandb
 timm
-xformers # Optional, used for memefficient attention
+#xformers # Optional, used for memefficient attention

+ 7 - 6
roma/models/encoders.py

@@ -8,7 +8,8 @@ import gc
 
 
 class ResNet50(nn.Module):
-    def __init__(self, pretrained=False, high_res = False, weights = None, dilation = None, freeze_bn = True, anti_aliased = False, early_exit = False, amp = False) -> None:
+    def __init__(self, pretrained=False, high_res = False, weights = None, 
+                 dilation = None, freeze_bn = True, anti_aliased = False, early_exit = False, amp = False, amp_dtype = torch.float16) -> None:
         super().__init__()
         if dilation is None:
             dilation = [False,False,False]
@@ -24,7 +25,7 @@ class ResNet50(nn.Module):
         self.freeze_bn = freeze_bn
         self.early_exit = early_exit
         self.amp = amp
-        self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+        self.amp_dtype = amp_dtype
 
     def forward(self, x, **kwargs):
         with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype):
@@ -56,11 +57,11 @@ class ResNet50(nn.Module):
                 pass
 
 class VGG19(nn.Module):
-    def __init__(self, pretrained=False, amp = False) -> None:
+    def __init__(self, pretrained=False, amp = False, amp_dtype = torch.float16) -> None:
         super().__init__()
         self.layers = nn.ModuleList(tvm.vgg19_bn(pretrained=pretrained).features[:40])
         self.amp = amp
-        self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+        self.amp_dtype = amp_dtype
 
     def forward(self, x, **kwargs):
         with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype):
@@ -74,7 +75,7 @@ class VGG19(nn.Module):
             return feats
 
 class CNNandDinov2(nn.Module):
-    def __init__(self, cnn_kwargs = None, amp = False, use_vgg = False, dinov2_weights = None):
+    def __init__(self, cnn_kwargs = None, amp = False, use_vgg = False, dinov2_weights = None, amp_dtype = torch.float16):
         super().__init__()
         if dinov2_weights is None:
             dinov2_weights = torch.hub.load_state_dict_from_url("https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth", map_location="cpu")
@@ -94,7 +95,7 @@ class CNNandDinov2(nn.Module):
         else:
             self.cnn = VGG19(**cnn_kwargs)
         self.amp = amp
-        self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+        self.amp_dtype = amp_dtype
         if self.amp:
             dinov2_vitl14 = dinov2_vitl14.to(self.amp_dtype)
         self.dinov2_vitl14 = [dinov2_vitl14] # ugly hack to not show parameters to DDP

+ 5 - 4
roma/models/matcher.py

@@ -38,6 +38,7 @@ class ConvRefiner(nn.Module):
         sample_mode = "bilinear",
         norm_type = nn.BatchNorm2d,
         bn_momentum = 0.1,
+        amp_dtype = torch.float16,
     ):
         super().__init__()
         self.bn_momentum = bn_momentum
@@ -72,7 +73,7 @@ class ConvRefiner(nn.Module):
         self.disable_local_corr_grad = disable_local_corr_grad
         self.is_classifier = is_classifier
         self.sample_mode = sample_mode
-        self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+        self.amp_dtype = amp_dtype
         
     def create_block(
         self,
@@ -275,7 +276,7 @@ class Decoder(nn.Module):
     def __init__(
         self, embedding_decoder, gps, proj, conv_refiner, detach=False, scales="all", pos_embeddings = None,
         num_refinement_steps_per_scale = 1, warp_noise_std = 0.0, displacement_dropout_p = 0.0, gm_warp_dropout_p = 0.0,
-        flow_upsample_mode = "bilinear"
+        flow_upsample_mode = "bilinear", amp_dtype = torch.float16,
     ):
         super().__init__()
         self.embedding_decoder = embedding_decoder
@@ -297,7 +298,7 @@ class Decoder(nn.Module):
         self.displacement_dropout_p = displacement_dropout_p
         self.gm_warp_dropout_p = gm_warp_dropout_p
         self.flow_upsample_mode = flow_upsample_mode
-        self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+        self.amp_dtype = amp_dtype
         
     def get_placeholder_flow(self, b, h, w, device):
         coarse_coords = torch.meshgrid(
@@ -361,7 +362,7 @@ class Decoder(nn.Module):
             corresps[ins] = {}
             f1_s, f2_s = f1[ins], f2[ins]
             if new_scale in self.proj:
-                with torch.autocast("cuda", self.amp_dtype):
+                with torch.autocast("cuda", dtype = self.amp_dtype):
                     f1_s, f2_s = self.proj[new_scale](f1_s), self.proj[new_scale](f2_s)
 
             if ins in coarse_scales:

+ 2 - 2
roma/models/transformer/__init__.py

@@ -9,7 +9,7 @@ from .dinov2 import vit_large
 
 class TransformerDecoder(nn.Module):
     def __init__(self, blocks, hidden_dim, out_dim, is_classifier = False, *args, 
-                 amp = False, pos_enc = True, learned_embeddings = False, embedding_dim = None, **kwargs) -> None:
+                 amp = False, pos_enc = True, learned_embeddings = False, embedding_dim = None, amp_dtype = torch.float16, **kwargs) -> None:
         super().__init__(*args, **kwargs)
         self.blocks = blocks
         self.to_out = nn.Linear(hidden_dim, out_dim)
@@ -18,7 +18,7 @@ class TransformerDecoder(nn.Module):
         self._scales = [16]
         self.is_classifier = is_classifier
         self.amp = amp
-        self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+        self.amp_dtype = amp_dtype
         self.pos_enc = pos_enc
         self.learned_embeddings = learned_embeddings
         if self.learned_embeddings:

+ 1 - 1
roma/utils/kde.py

@@ -2,7 +2,7 @@ import torch
 
 def kde(x, std = 0.1):
     # use a gaussian kernel to estimate density
-    x = x.half() # Do it in half precision
+    x = x.half() # Do it in half precision TODO: remove hardcoding
     scores = (-torch.cdist(x,x)**2/(2*std**2)).exp()
     density = scores.sum(dim=-1)
     return density

+ 1 - 3
roma/utils/local_correlation.py

@@ -12,8 +12,6 @@ def local_correlation(
     r = local_radius
     K = (2*r+1)**2
     B, c, h, w = feature0.size()
-    feature0 = feature0.half()
-    feature1 = feature1.half()
     corr = torch.empty((B,K,h,w), device = feature0.device, dtype=feature0.dtype)
     if flow is None:
         # If flow is None, assume feature0 and feature1 are aligned
@@ -37,7 +35,7 @@ def local_correlation(
         ].expand(1, 2*r+1, 2*r+1, 2).reshape(1, (2*r+1)**2, 2)
     for _ in range(B):
         with torch.no_grad():
-            local_window_coords = (coords[_,:,:,None]+local_window[:,None,None]).reshape(1,h,w*(2*r+1)**2,2).half()
+            local_window_coords = (coords[_,:,:,None]+local_window[:,None,None]).reshape(1,h,w*(2*r+1)**2,2)
             window_feature = F.grid_sample(
                 feature1[_:_+1], local_window_coords, padding_mode=padding_mode, align_corners=False, mode = sample_mode, #
             )