Browse Source

Accelerate RoMa Inference (#147)

* pad to multiples of 8

* only extract features for unique images

* add compile args and don't pad by default

---------

Co-authored-by: Johan Edstedt <johan.edstedt@liu.se>
Jiang Xudong 3 months ago
parent
commit
77f8d68803

+ 1 - 0
experiments/eval_roma_outdoor.py

@@ -49,6 +49,7 @@ def test_hpatches(model, name):
 
 if __name__ == "__main__":
     from romatch import roma_outdoor
+
     device = "cuda"
     model = roma_outdoor(device = device, coarse_res = 672, upsample_res = 1344)
     experiment_name = "roma_latest"

+ 17 - 2
romatch/models/matcher.py

@@ -172,6 +172,11 @@ class ConvRefiner(nn.Module):
                 d = torch.cat((x, x_hat), dim=1)
             if self.concat_logits:
                 d = torch.cat((d, logits), dim=1)
+            # pad d if needed
+            channel_d = d.shape[1]
+            channel_block1 = self.block1[0].in_channels
+            if channel_d != channel_block1:
+                d = F.pad(d, (0, 0, 0, 0, 0, channel_block1 - channel_d))
             d = self.block1(d)
             d = self.hidden_blocks(d)
         d = self.out_conv(d.float())
@@ -583,6 +588,16 @@ class RegressionMatcher(nn.Module):
             return self.upsample_res
 
     def extract_backbone_features(self, batch, batched=True, upsample=False):
+        if 'unique_images' in batch:
+            unique_images = batch['unique_images']
+            im_AB_idx = batch['im_AB_idx']
+            feature_pyramid0 = self.encoder(unique_images, upsample=upsample)
+            feature_pyramid = {
+                scale: feature_pyramid0[scale][im_AB_idx]
+                for scale in feature_pyramid0
+            }
+            return feature_pyramid
+            
         x_q = batch["im_A"]
         x_s = batch["im_B"]
         if batched:
@@ -831,7 +846,7 @@ class RegressionMatcher(nn.Module):
         if symmetric:
             corresps = self.forward_symmetric(batch, scale_factor=scale_factor)
         else:
-            corresps = self.forward(batch, batched=True, scale_factor=scale_factor)
+            corresps = self(batch, batched=True, scale_factor=scale_factor)
 
         if self.upsample_preds:
             hs, ws = self.upsample_res
@@ -884,7 +899,7 @@ class RegressionMatcher(nn.Module):
                     batch, upsample=True, batched=True, scale_factor=scale_factor
                 )
             else:
-                corresps = self.forward(
+                corresps = self(
                     batch, batched=True, upsample=True, scale_factor=scale_factor
                 )
 

+ 19 - 3
romatch/models/model_zoo/__init__.py

@@ -1,6 +1,6 @@
 from typing import Union
 import torch
-from .roma_models import roma_model, tiny_roma_v1_model
+from .roma_models import roma_model,roma_model_pad, tiny_roma_v1_model
 
 
 weight_urls = {
@@ -38,7 +38,12 @@ def roma_outdoor(
     symmetric=True,
     use_custom_corr=True,
     upsample_preds=True,
+    with_padding=False,
+    do_compile=False,
 ):
+    if torch.get_float32_matmul_precision() != "highest":
+        raise RuntimeError("Float32 matmul precision must be set to highest for RoMa. See also https://github.com/Parskatt/RoMaV2/issues/35")
+
     if weights is None:
         weights = torch.hub.load_state_dict_from_url(
             weight_urls["romatch"]["outdoor"], map_location=device
@@ -47,7 +52,8 @@ def roma_outdoor(
         dinov2_weights = torch.hub.load_state_dict_from_url(
             weight_urls["dinov2"], map_location=device
         )
-    model = roma_model(
+    model_init = roma_model if not with_padding else roma_model_pad
+    model = model_init(
         resolution=coarse_res,
         upsample_preds=upsample_preds,
         weights=weights,
@@ -58,6 +64,8 @@ def roma_outdoor(
         use_custom_corr=use_custom_corr,
         upsample_res=upsample_res,
     )
+    if do_compile:
+        model.compile()
     return model
 
 
@@ -71,7 +79,12 @@ def roma_indoor(
     symmetric=True,
     use_custom_corr=True,
     upsample_preds=True,
+    with_padding=False,
+    do_compile=False,
 ):
+    if torch.get_float32_matmul_precision() != "highest":
+        raise RuntimeError("Float32 matmul precision must be set to highest for RoMa. See also https://github.com/Parskatt/RoMaV2/issues/35")
+
     if weights is None:
         weights = torch.hub.load_state_dict_from_url(
             weight_urls["romatch"]["indoor"], map_location=device
@@ -80,7 +93,8 @@ def roma_indoor(
         dinov2_weights = torch.hub.load_state_dict_from_url(
             weight_urls["dinov2"], map_location=device
         )
-    model = roma_model(
+    model_init = roma_model if not with_padding else roma_model_pad
+    model = model_init(
         resolution=coarse_res,
         upsample_preds=upsample_preds,
         weights=weights,
@@ -91,4 +105,6 @@ def roma_indoor(
         use_custom_corr=use_custom_corr,
         upsample_res=upsample_res,
     )
+    if do_compile:
+        model.compile()
     return model

+ 194 - 0
romatch/models/model_zoo/roma_models.py

@@ -28,6 +28,200 @@ def tiny_roma_v1_model(
         model.load_state_dict(weights)
     return model
 
+def pad_refiner_state_dict(state_dict_old,state_dict_pad):
+    """Zero-pad ``decoder.conv_refiner`` tensors to the shapes of a padded model.
+
+    For every key of ``state_dict_pad`` that starts with
+    ``'decoder.conv_refiner'``, the tensor stored under the same key in
+    ``state_dict_old`` is compared with the target shape. When the shapes
+    differ, a zero tensor of the target shape is created (same device and
+    dtype as the source tensor) and the source values are copied into its
+    leading corner, i.e. ``new[:s0, :s1, ...] = old``. All other entries are
+    passed through unchanged.
+
+    Args:
+        state_dict_old: Mapping from parameter name to tensor holding the
+            weights to be padded. It is not modified.
+        state_dict_pad: Mapping from parameter name to tensor whose shapes
+            define the padding target (only the shapes are read).
+
+    Returns:
+        A shallow copy of ``state_dict_old`` in which the mismatching refiner
+        tensors are replaced by their zero-padded versions.
+
+    Raises:
+        KeyError: If a refiner key of ``state_dict_pad`` is absent from
+            ``state_dict_old``.
+        ValueError: If a source tensor has a different rank than its target,
+            or is larger than the target along any dimension.
+    """
+    import copy
+
+    # Work on a shallow copy so a weights dict owned by the caller can still
+    # be loaded into an unpadded model afterwards. copy.copy (unlike dict())
+    # keeps instance attributes such as the `_metadata` that torch attaches
+    # to state dicts.
+    state_dict_new = copy.copy(state_dict_old)
+    for key, target in state_dict_pad.items():
+        if not key.startswith('decoder.conv_refiner'):
+            continue
+        param = state_dict_old[key]
+        shape_old = param.shape
+        shape_pad = target.shape
+        if shape_old == shape_pad:
+            continue
+        # Without this check a rank mismatch could broadcast silently and an
+        # oversized source would fail with an obscure assignment error.
+        if len(shape_old) != len(shape_pad) or any(
+            old > new for old, new in zip(shape_old, shape_pad)
+        ):
+            raise ValueError(
+                f"Cannot pad '{key}' from shape {tuple(shape_old)} "
+                f"to shape {tuple(shape_pad)}"
+            )
+        new_param = torch.zeros(shape_pad, device=param.device, dtype=param.dtype)
+        slices = tuple(slice(0, s) for s in shape_old)
+        new_param[slices] = param
+        state_dict_new[key] = new_param
+    return state_dict_new
+
+def roma_model_pad(
+    resolution,
+    upsample_preds,
+    device=None,
+    weights=None,
+    dinov2_weights=None,
+    amp_dtype: torch.dtype = torch.float16,
+    use_custom_corr=True,
+    symmetric=True,
+    upsample_res=None,
+    sample_thresh=0.05,
+    sample_mode="threshold_balanced",
+    attenuate_cert = True,
+    refiner_channels=(1384, 1144, 576, 144, 24),
+    **kwargs,
+):
+    """Build a ``RegressionMatcher`` whose conv refiners use padded channel widths.
+
+    The refiner at each scale ("16", "8", "4", "2", "1") is created with the
+    corresponding entry of ``refiner_channels`` as both its input and hidden
+    width (the defaults are all multiples of 8). When ``weights`` is given,
+    its ``decoder.conv_refiner`` tensors are zero-padded to the shapes of the
+    freshly built model via ``pad_refiner_state_dict`` before being loaded.
+
+    Args:
+        resolution: Coarse resolution as ``(h, w)`` or a single int used for
+            both; each side must be a multiple of 14.
+        upsample_preds: Forwarded to ``RegressionMatcher``.
+        device: Device the matcher is moved to. If ``str(device) == "cpu"``,
+            ``amp_dtype`` is overridden with ``torch.float32``.
+        weights: Optional state dict for the matcher. When ``None`` the model
+            keeps its initial parameters and nothing is loaded.
+        dinov2_weights: Forwarded to ``CNNandDinov2``.
+        amp_dtype: Forwarded to ``CNNandDinov2``.
+        use_custom_corr: Forwarded to every ``ConvRefiner``; forced to
+            ``False`` (with a warning) when ``sys.platform`` is not "linux".
+        symmetric: Forwarded to ``RegressionMatcher``.
+        upsample_res: Upsample resolution as ``(h, w)`` or a single int used
+            for both; forwarded to ``RegressionMatcher``.
+        sample_thresh: Forwarded to ``RegressionMatcher``.
+        sample_mode: Forwarded to ``RegressionMatcher``.
+        attenuate_cert: Forwarded to ``RegressionMatcher``.
+        refiner_channels: Five channel widths, one per refiner scale in the
+            order "16", "8", "4", "2", "1".
+        **kwargs: Extra keyword arguments forwarded to ``RegressionMatcher``.
+
+    Returns:
+        The constructed ``RegressionMatcher``, moved to ``device``.
+
+    Raises:
+        AssertionError: If a side of ``resolution`` is not a multiple of 14.
+    """
+    if sys.platform != "linux":
+        use_custom_corr = False
+        warnings.warn("Local correlation is not supported on non-Linux platforms, setting use_custom_corr to False")
+    if isinstance(resolution, int):
+        resolution = (resolution, resolution)
+    if isinstance(upsample_res, int):
+        upsample_res = (upsample_res, upsample_res)
+
+    if str(device) == "cpu":
+        amp_dtype = torch.float32
+
+    assert resolution[0] % 14 == 0, "Needs to be multiple of 14 for backbone"
+    assert resolution[1] % 14 == 0, "Needs to be multiple of 14 for backbone"
+
+    logger.info(
+        f"Using coarse resolution {resolution}, and upsample res {upsample_res}"
+    )
+
+    warnings.filterwarnings(
+        "ignore", category=UserWarning, message="TypedStorage is deprecated"
+    )
+    gp_dim = 512
+    feat_dim = 512
+    decoder_dim = gp_dim + feat_dim
+    cls_to_coord_res = 64
+    coordinate_decoder = TransformerDecoder(
+        nn.Sequential(
+            *[Block(decoder_dim, 8, attn_class=MemEffAttention) for _ in range(5)]
+        ),
+        decoder_dim,
+        cls_to_coord_res**2 + 1,
+        is_classifier=True,
+        amp=True,
+        pos_enc=False,
+    )
+    dw = True
+    hidden_blocks = 8
+    kernel_size = 5
+    displacement_emb = "linear"
+    disable_local_corr_grad = True
+    # Settings shared by the refiners of all scales; only the channel widths,
+    # displacement embedding size and local correlation radius vary below.
+    partial_conv_refiner = partial(
+        ConvRefiner,
+        kernel_size=kernel_size,
+        dw=dw,
+        hidden_blocks=hidden_blocks,
+        displacement_emb=displacement_emb,
+        corr_in_other=True,
+        amp=True,
+        disable_local_corr_grad=disable_local_corr_grad,
+        bn_momentum=0.01,
+        use_custom_corr=use_custom_corr,
+    )
+
+    conv_refiner = nn.ModuleDict(
+        {
+            "16": partial_conv_refiner(
+                refiner_channels[0],
+                refiner_channels[0],
+                2 + 1,
+                displacement_emb_dim=128,
+                local_corr_radius=7,
+            ),
+            "8": partial_conv_refiner(
+                refiner_channels[1],
+                refiner_channels[1],
+                2 + 1,
+                displacement_emb_dim=64,
+                local_corr_radius=3,
+            ),
+            "4": partial_conv_refiner(
+                refiner_channels[2],
+                refiner_channels[2],
+                2 + 1,
+                displacement_emb_dim=32,
+                local_corr_radius=2,
+            ),
+            "2": partial_conv_refiner(
+                refiner_channels[3],
+                refiner_channels[3],
+                2 + 1,
+                displacement_emb_dim=16,
+            ),
+            "1": partial_conv_refiner(
+                refiner_channels[4],
+                refiner_channels[4],
+                2 + 1,
+                displacement_emb_dim=6,
+            ),
+        }
+    )
+    kernel_temperature = 0.2
+    learn_temperature = False
+    no_cov = True
+    kernel = CosKernel
+    only_attention = False
+    basis = "fourier"
+    gp16 = GP(
+        kernel,
+        T=kernel_temperature,
+        learn_temperature=learn_temperature,
+        only_attention=only_attention,
+        gp_dim=gp_dim,
+        basis=basis,
+        no_cov=no_cov,
+    )
+    gps = nn.ModuleDict({"16": gp16})
+    proj16 = nn.Sequential(nn.Conv2d(1024, 512, 1, 1), nn.BatchNorm2d(512))
+    proj8 = nn.Sequential(nn.Conv2d(512, 512, 1, 1), nn.BatchNorm2d(512))
+    proj4 = nn.Sequential(nn.Conv2d(256, 256, 1, 1), nn.BatchNorm2d(256))
+    proj2 = nn.Sequential(nn.Conv2d(128, 64, 1, 1), nn.BatchNorm2d(64))
+    proj1 = nn.Sequential(nn.Conv2d(64, 9, 1, 1), nn.BatchNorm2d(9))
+    proj = nn.ModuleDict(
+        {
+            "16": proj16,
+            "8": proj8,
+            "4": proj4,
+            "2": proj2,
+            "1": proj1,
+        }
+    )
+    displacement_dropout_p = 0.0
+    gm_warp_dropout_p = 0.0
+    decoder = Decoder(
+        coordinate_decoder,
+        gps,
+        proj,
+        conv_refiner,
+        detach=True,
+        scales=["16", "8", "4", "2", "1"],
+        displacement_dropout_p=displacement_dropout_p,
+        gm_warp_dropout_p=gm_warp_dropout_p,
+    )
+
+    encoder = CNNandDinov2(
+        cnn_kwargs=dict(pretrained=False, amp=True),
+        amp=True,
+        dinov2_weights=dinov2_weights,
+        amp_dtype=amp_dtype,
+    )
+    h, w = resolution
+
+    matcher = RegressionMatcher(
+        encoder,
+        decoder,
+        h=h,
+        w=w,
+        upsample_preds=upsample_preds,
+        upsample_res=upsample_res,
+        symmetric=symmetric,
+        attenuate_cert=attenuate_cert,
+        sample_mode=sample_mode,
+        sample_thresh=sample_thresh,
+        **kwargs,
+    ).to(device)
+    if weights is not None:
+        # The supplied weights may carry narrower refiner tensors than this
+        # model; zero-pad them to the model's shapes before loading.
+        weights = pad_refiner_state_dict(weights, matcher.state_dict())
+        matcher.load_state_dict(weights)
+    return matcher
+
+
 
 def roma_model(
     resolution,

+ 1 - 1
tests/test_mega_dense.py

@@ -9,7 +9,7 @@ def test_mega_dense(model, name):
 if __name__ == "__main__":
     from romatch import roma_outdoor
     device = "cuda"
-    model = roma_outdoor(device = device, coarse_res = 560, use_custom_corr=True, symmetric = False, upsample_preds = False)
+    model = roma_outdoor(device = device, coarse_res = 560, use_custom_corr=True, symmetric = False, upsample_preds = False, do_compile = True, with_padding = True)
     experiment_name = "roma_latest"
     results = test_mega_dense(model, experiment_name)
     print(results)