فهرست منبع

only fused-local-corr for linux and make matching params visible (#122)

* only fused-local-corr for linux

* fix bug in test
Johan Edstedt 7 ماه پیش
والد
کامیت
60f9a48934

+ 1 - 1
pyproject.toml

@@ -10,7 +10,7 @@ requires-python = ">=3.9"
 dependencies = [
     "albumentations",
     "einops",
-    "fused-local-corr>=0.2.2",
+    "fused-local-corr>=0.2.2 ; sys_platform == 'linux'",
     "h5py",
     "kornia",
     "loguru",

+ 8 - 2
romatch/models/matcher.py

@@ -1,5 +1,6 @@
 import os
 import math
+import sys
 import numpy as np
 import torch
 import torch.nn as nn
@@ -46,6 +47,9 @@ class ConvRefiner(nn.Module):
         use_custom_corr=False,
     ):
         super().__init__()
+        if sys.platform != "linux":
+            warn("Local correlation is not supported on non-Linux platforms, setting use_custom_corr to False")
+            use_custom_corr = False
         self.bn_momentum = bn_momentum
         self.block1 = self.create_block(
             in_dim,
@@ -553,8 +557,10 @@ class RegressionMatcher(nn.Module):
         sample_mode="threshold_balanced",
         upsample_preds=False,
         symmetric=False,
+        sample_thresh=0.05,
         name=None,
         attenuate_cert=None,
+        upsample_res=None,
     ):
         super().__init__()
         self.attenuate_cert = attenuate_cert
@@ -566,9 +572,9 @@ class RegressionMatcher(nn.Module):
         self.og_transforms = get_tuple_transform_ops(resize=None, normalize=True)
         self.sample_mode = sample_mode
         self.upsample_preds = upsample_preds
-        self.upsample_res = (14 * 16 * 6, 14 * 16 * 6)
+        self.upsample_res = upsample_res or (14 * 16 * 6, 14 * 16 * 6)
         self.symmetric = symmetric
-        self.sample_thresh = 0.05
+        self.sample_thresh = sample_thresh
 
     def get_output_resolution(self):
         if not self.upsample_preds:

+ 10 - 31
romatch/models/model_zoo/__init__.py

@@ -1,7 +1,6 @@
 from typing import Union
 import torch
 from .roma_models import roma_model, tiny_roma_v1_model
-from loguru import logger
 
 
 weight_urls = {
@@ -37,20 +36,9 @@ def roma_outdoor(
     upsample_res: Union[int, tuple[int, int]] = 864,
     amp_dtype: torch.dtype = torch.float16,
     symmetric=True,
-    use_custom_corr=False,
+    use_custom_corr=True,
     upsample_preds=True,
 ):
-    if isinstance(coarse_res, int):
-        coarse_res = (coarse_res, coarse_res)
-    if isinstance(upsample_res, int):
-        upsample_res = (upsample_res, upsample_res)
-
-    if str(device) == "cpu":
-        amp_dtype = torch.float32
-
-    assert coarse_res[0] % 14 == 0, "Needs to be multiple of 14 for backbone"
-    assert coarse_res[1] % 14 == 0, "Needs to be multiple of 14 for backbone"
-
     if weights is None:
         weights = torch.hub.load_state_dict_from_url(
             weight_urls["romatch"]["outdoor"], map_location=device
@@ -68,10 +56,7 @@ def roma_outdoor(
         amp_dtype=amp_dtype,
         symmetric=symmetric,
         use_custom_corr=use_custom_corr,
-    )
-    model.upsample_res = upsample_res
-    logger.info(
-        f"Using coarse resolution {coarse_res}, and upsample res {model.upsample_res}"
+        upsample_res=upsample_res,
     )
     return model
 
@@ -83,15 +68,10 @@ def roma_indoor(
     coarse_res: Union[int, tuple[int, int]] = 560,
     upsample_res: Union[int, tuple[int, int]] = 864,
     amp_dtype: torch.dtype = torch.float16,
+    symmetric=True,
+    use_custom_corr=True,
+    upsample_preds=True,
 ):
-    if isinstance(coarse_res, int):
-        coarse_res = (coarse_res, coarse_res)
-    if isinstance(upsample_res, int):
-        upsample_res = (upsample_res, upsample_res)
-
-    assert coarse_res[0] % 14 == 0, "Needs to be multiple of 14 for backbone"
-    assert coarse_res[1] % 14 == 0, "Needs to be multiple of 14 for backbone"
-
     if weights is None:
         weights = torch.hub.load_state_dict_from_url(
             weight_urls["romatch"]["indoor"], map_location=device
@@ -102,14 +82,13 @@ def roma_indoor(
         )
     model = roma_model(
         resolution=coarse_res,
-        upsample_preds=True,
+        upsample_preds=upsample_preds,
         weights=weights,
         dinov2_weights=dinov2_weights,
         device=device,
         amp_dtype=amp_dtype,
+        symmetric=symmetric,
+        use_custom_corr=use_custom_corr,
+        upsample_res=upsample_res,
     )
-    model.upsample_res = upsample_res
-    logger.info(
-        f"Using coarse resolution {coarse_res}, and upsample res {model.upsample_res}"
-    )
-    return model
+    return model

+ 38 - 8
romatch/models/model_zoo/roma_models.py

@@ -1,17 +1,21 @@
-from functools import partial
+import sys
 import warnings
-import torch.nn as nn
+from functools import partial
+
 import torch
+import torch.nn as nn
+from loguru import logger
+
+from romatch.models.encoders import CNNandDinov2
 from romatch.models.matcher import (
+    GP,
     ConvRefiner,
     CosKernel,
-    GP,
     Decoder,
     RegressionMatcher,
 )
-from romatch.models.transformer import Block, TransformerDecoder, MemEffAttention
-from romatch.models.encoders import CNNandDinov2
 from romatch.models.tiny import TinyRoMa
+from romatch.models.transformer import Block, MemEffAttention, TransformerDecoder
 
 
 def tiny_roma_v1_model(
@@ -32,10 +36,35 @@ def roma_model(
     weights=None,
     dinov2_weights=None,
     amp_dtype: torch.dtype = torch.float16,
-    use_custom_corr=False,
+    use_custom_corr=True,
     symmetric=True,
+    upsample_res=None,
+    sample_thresh=0.05,
+    sample_mode="threshold_balanced",
+    attenuate_cert = True,
     **kwargs,
 ):
+    if sys.platform != "linux":
+        use_custom_corr = False
+        warnings.warn("Local correlation is not supported on non-Linux platforms, setting use_custom_corr to False")
+    if isinstance(resolution, int):
+        resolution = (resolution, resolution)
+    if isinstance(upsample_res, int):
+        upsample_res = (upsample_res, upsample_res)
+
+    if str(device) == "cpu":
+        amp_dtype = torch.float32
+
+    assert resolution[0] % 14 == 0, "Needs to be multiple of 14 for backbone"
+    assert resolution[1] % 14 == 0, "Needs to be multiple of 14 for backbone"
+
+    logger.info(
+        f"Using coarse resolution {resolution}, and upsample res {upsample_res}"
+    )
+
+    if sys.platform != "linux":
+        use_custom_corr = False
+        warnings.warn("Local correlation is not supported on non-Linux platforms, setting use_custom_corr to False")
     warnings.filterwarnings(
         "ignore", category=UserWarning, message="TypedStorage is deprecated"
     )
@@ -158,17 +187,18 @@ def roma_model(
         amp_dtype=amp_dtype,
     )
     h, w = resolution
-    attenuate_cert = True
-    sample_mode = "threshold_balanced"
+    
     matcher = RegressionMatcher(
         encoder,
         decoder,
         h=h,
         w=w,
         upsample_preds=upsample_preds,
+        upsample_res=upsample_res,
         symmetric=symmetric,
         attenuate_cert=attenuate_cert,
         sample_mode=sample_mode,
+        sample_thresh=sample_thresh,
         **kwargs,
     ).to(device)
     matcher.load_state_dict(weights)

+ 1 - 1
romatch/utils/local_correlation.py

@@ -1,7 +1,6 @@
 from typing import Literal
 import torch
 import torch.nn.functional as F
-import local_corr
 
 
 def local_corr_wrapper(
@@ -20,6 +19,7 @@ def local_corr_wrapper(
     sample_mode: Literal["bilinear", "nearest"] = "bilinear",
     dtype=torch.float32,
 ):
+    import local_corr
     assert padding_mode == "zeros"
     warp = (coords[..., None, :] + local_window[:, None, None]).reshape(B, h * w, K, 2)
     corr = (

+ 3 - 3
tests/test_mega1500.py

@@ -16,7 +16,7 @@ if __name__ == "__main__":
     # gotten on 3.12 env with torch 2.8.0
     reference_scores = [0.6271474434923545, 0.7673889435429945, 0.8642099162282599] # slightly worse.
     # old_reference_scores = [0.6235757679569996, 0.7648007367330985, 0.8630483724961098]
-    assert np.isclose(results[0], reference_scores[0], atol=3e-1 / 100)
-    assert np.isclose(results[1], reference_scores[1], atol=2e-1 / 100)
-    assert np.isclose(results[2], reference_scores[2], atol=1e-1 / 100)
+    assert np.isclose(results["auc_5"], reference_scores[0], atol=3e-1 / 100)
+    assert np.isclose(results["auc_10"], reference_scores[1], atol=2e-1 / 100)
+    assert np.isclose(results["auc_20"], reference_scores[2], atol=1e-1 / 100)
     

+ 5 - 5
uv.lock

@@ -542,9 +542,9 @@ name = "fused-local-corr"
 version = "0.2.2"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
-    { name = "numpy", version = "2.0.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" },
-    { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" },
-    { name = "torch" },
+    { name = "numpy", version = "2.0.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10' and sys_platform == 'linux'" },
+    { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10' and sys_platform == 'linux'" },
+    { name = "torch", marker = "sys_platform == 'linux'" },
 ]
 wheels = [
     { url = "https://files.pythonhosted.org/packages/c0/72/7fd886f0fcd4b66d10c341d78a960adfab0b807ff1e72998bee123ebcb7a/fused_local_corr-0.2.2-cp310-cp310-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:573a766ae4b108491a085454f6044970b36e9e66259b23a0413ee2d7befb018a", size = 254718, upload-time = "2025-09-10T21:22:30.681Z" },
@@ -1930,7 +1930,7 @@ source = { editable = "." }
 dependencies = [
     { name = "albumentations" },
     { name = "einops" },
-    { name = "fused-local-corr" },
+    { name = "fused-local-corr", marker = "sys_platform == 'linux'" },
     { name = "h5py" },
     { name = "kornia" },
     { name = "loguru" },
@@ -1954,7 +1954,7 @@ dev = [
 requires-dist = [
     { name = "albumentations" },
     { name = "einops" },
-    { name = "fused-local-corr", specifier = ">=0.2.2" },
+    { name = "fused-local-corr", marker = "sys_platform == 'linux'", specifier = ">=0.2.2" },
     { name = "h5py" },
     { name = "kornia" },
     { name = "loguru" },