Johan Edstedt 3 лет назад
Родитель
Сommit
3e9d971c78
47 измененных файлов с 4838 добавлено и 2 удалено
  1. 5 0
      .gitignore
  2. 26 2
      README.md
  3. BIN
      assets/sacre_coeur_A.jpg
  4. BIN
      assets/sacre_coeur_B.jpg
  5. 2 0
      data/.gitignore
  6. 33 0
      demo/demo_fundamental.py
  7. 47 0
      demo/demo_match.py
  8. 320 0
      experiments/roma_indoor.py
  9. 323 0
      experiments/roma_outdoor.py
  10. 13 0
      requirements.txt
  11. 8 0
      roma/__init__.py
  12. 4 0
      roma/benchmarks/__init__.py
  13. 113 0
      roma/benchmarks/hpatches_sequences_homog_benchmark.py
  14. 106 0
      roma/benchmarks/megadepth_dense_benchmark.py
  15. 140 0
      roma/benchmarks/megadepth_pose_estimation_benchmark.py
  16. 143 0
      roma/benchmarks/scannet_benchmark.py
  17. 1 0
      roma/checkpointing/__init__.py
  18. 60 0
      roma/checkpointing/checkpoint.py
  19. 2 0
      roma/datasets/__init__.py
  20. 230 0
      roma/datasets/megadepth.py
  21. 160 0
      roma/datasets/scannet.py
  22. 1 0
      roma/losses/__init__.py
  23. 157 0
      roma/losses/robust_loss.py
  24. 1 0
      roma/models/__init__.py
  25. 118 0
      roma/models/encoders.py
  26. 649 0
      roma/models/matcher.py
  27. 30 0
      roma/models/model_zoo/__init__.py
  28. 157 0
      roma/models/model_zoo/roma_models.py
  29. 47 0
      roma/models/transformer/__init__.py
  30. 359 0
      roma/models/transformer/dinov2.py
  31. 12 0
      roma/models/transformer/layers/__init__.py
  32. 81 0
      roma/models/transformer/layers/attention.py
  33. 252 0
      roma/models/transformer/layers/block.py
  34. 59 0
      roma/models/transformer/layers/dino_head.py
  35. 35 0
      roma/models/transformer/layers/drop_path.py
  36. 28 0
      roma/models/transformer/layers/layer_scale.py
  37. 41 0
      roma/models/transformer/layers/mlp.py
  38. 89 0
      roma/models/transformer/layers/patch_embed.py
  39. 63 0
      roma/models/transformer/layers/swiglu_ffn.py
  40. 1 0
      roma/train/__init__.py
  41. 102 0
      roma/train/train.py
  42. 16 0
      roma/utils/__init__.py
  43. 8 0
      roma/utils/kde.py
  44. 47 0
      roma/utils/local_correlation.py
  45. 118 0
      roma/utils/transforms.py
  46. 622 0
      roma/utils/utils.py
  47. 9 0
      setup.py

+ 5 - 0
.gitignore

@@ -0,0 +1,5 @@
+*.egg-info*
+*.vscode*
+*__pycache__*
+vis*
+workspace*

+ 26 - 2
README.md

@@ -1,2 +1,26 @@
-# RoMa
-Soon some exiting stuff will be here, stay posted!
+# RoMa: Revisiting Robust Losses for Dense Feature Matching
+
+**NOTE!!! Very early code, there might be bugs**
+
+The experiments in the paper are provided in the [experiments folder](experiments).
+The codebase is in the [roma folder](roma).
+
+## Setup/Install
+In your python environment (tested on Linux python 3.10), run:
+```bash
+pip install -e .
+```
+
+## Training
+1. First follow the instructions provided here: https://github.com/Parskatt/DKM for downloading and preprocessing datasets.
+2. Run the relevant experiment, e.g.,
+```bash
+torchrun --nproc_per_node=4 --nnodes=1 --rdzv_backend=c10d experiments/roma_outdoor.py
+```
+## Testing
+```bash
+python experiments/roma_outdoor.py --only_test --benchmark mega-1500
+```
+
+## Acknowledgement
+Our codebase builds on the code in [DKM](https://github.com/Parskatt/DKM).

BIN
assets/sacre_coeur_A.jpg


BIN
assets/sacre_coeur_B.jpg


+ 2 - 0
data/.gitignore

@@ -0,0 +1,2 @@
+*
+!.gitignore

+ 33 - 0
demo/demo_fundamental.py

@@ -0,0 +1,33 @@
+from PIL import Image
+import torch
+import cv2
+from roma import roma_outdoor
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+
+if __name__ == "__main__":
+    from argparse import ArgumentParser
+    parser = ArgumentParser()
+    parser.add_argument("--im_A_path", default="assets/sacre_coeur_A.jpg", type=str)
+    parser.add_argument("--im_B_path", default="assets/sacre_coeur_B.jpg", type=str)
+
+    args, _ = parser.parse_known_args()
+    im1_path = args.im_A_path
+    im2_path = args.im_B_path
+
+    # Create model
+    roma_model = roma_outdoor(device=device)
+
+
+    W_A, H_A = Image.open(im1_path).size
+    W_B, H_B = Image.open(im2_path).size
+
+    # Match
+    warp, certainty = roma_model.match(im1_path, im2_path, device=device)
+    # Sample matches for estimation
+    matches, certainty = roma_model.sample(warp, certainty)
+    kpts1, kpts2 = roma_model.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B)    
+    F, mask = cv2.findFundamentalMat(
+        kpts1.cpu().numpy(), kpts2.cpu().numpy(), ransacReprojThreshold=0.2, method=cv2.USAC_MAGSAC, confidence=0.999999, maxIters=10000
+    )

+ 47 - 0
demo/demo_match.py

@@ -0,0 +1,47 @@
+from PIL import Image
+import torch
+import torch.nn.functional as F
+import numpy as np
+from roma.utils.utils import tensor_to_pil
+
+from roma import roma_indoor
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+
+if __name__ == "__main__":
+    from argparse import ArgumentParser
+    parser = ArgumentParser()
+    parser.add_argument("--im_A_path", default="assets/sacre_coeur_A.jpg", type=str)
+    parser.add_argument("--im_B_path", default="assets/sacre_coeur_B.jpg", type=str)
+    parser.add_argument("--save_path", default="demo/dkmv3_warp_sacre_coeur.jpg", type=str)
+
+    args, _ = parser.parse_known_args()
+    im1_path = args.im_A_path
+    im2_path = args.im_B_path
+    save_path = args.save_path
+
+    # Create model
+    roma_model = roma_indoor(device=device)
+
+    H, W = roma_model.get_output_resolution()
+
+    im1 = Image.open(im1_path).resize((W, H))
+    im2 = Image.open(im2_path).resize((W, H))
+
+    # Match
+    warp, certainty = roma_model.match(im1_path, im2_path, device=device)
+    # Sampling not needed, but can be done with model.sample(warp, certainty)
+    x1 = (torch.tensor(np.array(im1)) / 255).to(device).permute(2, 0, 1)
+    x2 = (torch.tensor(np.array(im2)) / 255).to(device).permute(2, 0, 1)
+
+    im2_transfer_rgb = F.grid_sample(
+    x2[None], warp[:,:W, 2:][None], mode="bilinear", align_corners=False
+    )[0]
+    im1_transfer_rgb = F.grid_sample(
+    x1[None], warp[:, W:, :2][None], mode="bilinear", align_corners=False
+    )[0]
+    warp_im = torch.cat((im2_transfer_rgb,im1_transfer_rgb),dim=2)
+    white_im = torch.ones((H,2*W),device=device)
+    vis_im = certainty * warp_im + (1 - certainty) * white_im
+    tensor_to_pil(vis_im, unnormalize=False).save(save_path)

+ 320 - 0
experiments/roma_indoor.py

@@ -0,0 +1,320 @@
+import os
+import torch
+from argparse import ArgumentParser
+
+from torch import nn
+from torch.utils.data import ConcatDataset
+import torch.distributed as dist
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+import json
+import wandb
+from tqdm import tqdm
+
+from roma.benchmarks import MegadepthDenseBenchmark
+from roma.datasets.megadepth import MegadepthBuilder
+from roma.datasets.scannet import ScanNetBuilder
+from roma.losses.robust_loss import RobustLosses
+from roma.benchmarks import MegadepthDenseBenchmark, ScanNetBenchmark
+from roma.train.train import train_k_steps
+from roma.models.matcher import *
+from roma.models.transformer import Block, TransformerDecoder, MemEffAttention
+from roma.models.encoders import *
+from roma.checkpointing import CheckPoint
+
+resolutions = {"low":(448, 448), "medium":(14*8*5, 14*8*5), "high":(14*8*6, 14*8*6)}
+
+def get_model(pretrained_backbone=True, resolution = "medium", **kwargs):
+    gp_dim = 512
+    feat_dim = 512
+    decoder_dim = gp_dim + feat_dim
+    cls_to_coord_res = 64
+    coordinate_decoder = TransformerDecoder(
+        nn.Sequential(*[Block(decoder_dim, 8, attn_class=MemEffAttention) for _ in range(5)]), 
+        decoder_dim, 
+        cls_to_coord_res**2 + 1,
+        is_classifier=True,
+        amp = True,
+        pos_enc = False,)
+    dw = True
+    hidden_blocks = 8
+    kernel_size = 5
+    displacement_emb = "linear"
+    disable_local_corr_grad = True
+    
+    conv_refiner = nn.ModuleDict(
+        {
+            "16": ConvRefiner(
+                2 * 512+128+(2*7+1)**2,
+                2 * 512+128+(2*7+1)**2,
+                2 + 1,
+                kernel_size=kernel_size,
+                dw=dw,
+                hidden_blocks=hidden_blocks,
+                displacement_emb=displacement_emb,
+                displacement_emb_dim=128,
+                local_corr_radius = 7,
+                corr_in_other = True,
+                amp = True,
+                disable_local_corr_grad = disable_local_corr_grad,
+                bn_momentum = 0.01,
+            ),
+            "8": ConvRefiner(
+                2 * 512+64+(2*3+1)**2,
+                2 * 512+64+(2*3+1)**2,
+                2 + 1,
+                kernel_size=kernel_size,
+                dw=dw,
+                hidden_blocks=hidden_blocks,
+                displacement_emb=displacement_emb,
+                displacement_emb_dim=64,
+                local_corr_radius = 3,
+                corr_in_other = True,
+                amp = True,
+                disable_local_corr_grad = disable_local_corr_grad,
+                bn_momentum = 0.01,
+            ),
+            "4": ConvRefiner(
+                2 * 256+32+(2*2+1)**2,
+                2 * 256+32+(2*2+1)**2,
+                2 + 1,
+                kernel_size=kernel_size,
+                dw=dw,
+                hidden_blocks=hidden_blocks,
+                displacement_emb=displacement_emb,
+                displacement_emb_dim=32,
+                local_corr_radius = 2,
+                corr_in_other = True,
+                amp = True,
+                disable_local_corr_grad = disable_local_corr_grad,
+                bn_momentum = 0.01,
+            ),
+            "2": ConvRefiner(
+                2 * 64+16,
+                128+16,
+                2 + 1,
+                kernel_size=kernel_size,
+                dw=dw,
+                hidden_blocks=hidden_blocks,
+                displacement_emb=displacement_emb,
+                displacement_emb_dim=16,
+                amp = True,
+                disable_local_corr_grad = disable_local_corr_grad,
+                bn_momentum = 0.01,
+            ),
+            "1": ConvRefiner(
+                2 * 9 + 6,
+                24,
+                2 + 1,
+                kernel_size=kernel_size,
+                dw=dw,
+                hidden_blocks = hidden_blocks,
+                displacement_emb = displacement_emb,
+                displacement_emb_dim = 6,
+                amp = True,
+                disable_local_corr_grad = disable_local_corr_grad,
+                bn_momentum = 0.01,
+            ),
+        }
+    )
+    kernel_temperature = 0.2
+    learn_temperature = False
+    no_cov = True
+    kernel = CosKernel
+    only_attention = False
+    basis = "fourier"
+    gp16 = GP(
+        kernel,
+        T=kernel_temperature,
+        learn_temperature=learn_temperature,
+        only_attention=only_attention,
+        gp_dim=gp_dim,
+        basis=basis,
+        no_cov=no_cov,
+    )
+    gps = nn.ModuleDict({"16": gp16})
+    proj16 = nn.Sequential(nn.Conv2d(1024, 512, 1, 1), nn.BatchNorm2d(512))
+    proj8 = nn.Sequential(nn.Conv2d(512, 512, 1, 1), nn.BatchNorm2d(512))
+    proj4 = nn.Sequential(nn.Conv2d(256, 256, 1, 1), nn.BatchNorm2d(256))
+    proj2 = nn.Sequential(nn.Conv2d(128, 64, 1, 1), nn.BatchNorm2d(64))
+    proj1 = nn.Sequential(nn.Conv2d(64, 9, 1, 1), nn.BatchNorm2d(9))
+    proj = nn.ModuleDict({
+        "16": proj16,
+        "8": proj8,
+        "4": proj4,
+        "2": proj2,
+        "1": proj1,
+        })
+    displacement_dropout_p = 0.0
+    gm_warp_dropout_p = 0.0
+    decoder = Decoder(coordinate_decoder, 
+                      gps, 
+                      proj, 
+                      conv_refiner, 
+                      detach=True, 
+                      scales=["16", "8", "4", "2", "1"], 
+                      displacement_dropout_p = displacement_dropout_p,
+                      gm_warp_dropout_p = gm_warp_dropout_p)
+    h,w = resolutions[resolution]
+    encoder = CNNandDinov2(
+        cnn_kwargs = dict(
+            pretrained=pretrained_backbone,
+            amp = True),
+        amp = True,
+        use_vgg = True,
+    )
+    matcher = RegressionMatcher(encoder, decoder, h=h, w=w, alpha=1, beta=0,**kwargs)
+    return matcher
+
+def train(args):
+    dist.init_process_group('nccl')
+    #torch._dynamo.config.verbose=True
+    gpus = int(os.environ['WORLD_SIZE'])
+    # create model and move it to GPU with id rank
+    rank = dist.get_rank()
+    print(f"Start running DDP on rank {rank}")
+    device_id = rank % torch.cuda.device_count()
+    roma.LOCAL_RANK = device_id
+    torch.cuda.set_device(device_id)
+    
+    resolution = args.train_resolution
+    wandb_log = not args.dont_log_wandb
+    experiment_name = os.path.splitext(os.path.basename(__file__))[0]
+    wandb_mode = "online" if wandb_log and rank == 0 and False else "disabled"
+    wandb.init(project="roma", entity=args.wandb_entity, name=experiment_name, reinit=False, mode = wandb_mode)
+    checkpoint_dir = "workspace/checkpoints/"
+    h,w = resolutions[resolution]
+    model = get_model(pretrained_backbone=True, resolution=resolution, attenuate_cert = False).to(device_id)
+    # Num steps
+    global_step = 0
+    batch_size = args.gpu_batch_size
+    step_size = gpus*batch_size
+    roma.STEP_SIZE = step_size
+    
+    N = (32 * 250000)  # 250k steps of batch size 32
+    # checkpoint every
+    k = 25000 // roma.STEP_SIZE
+
+    # Data
+    mega = MegadepthBuilder(data_root="data/megadepth", loftr_ignore=True, imc21_ignore = True)
+    use_horizontal_flip_aug = True
+    rot_prob = 0
+    depth_interpolation_mode = "bilinear"
+    megadepth_train1 = mega.build_scenes(
+        split="train_loftr", min_overlap=0.01, shake_t=32, use_horizontal_flip_aug = use_horizontal_flip_aug, rot_prob = rot_prob,
+        ht=h,wt=w,
+    )
+    megadepth_train2 = mega.build_scenes(
+        split="train_loftr", min_overlap=0.35, shake_t=32, use_horizontal_flip_aug = use_horizontal_flip_aug, rot_prob = rot_prob,
+        ht=h,wt=w,
+    )
+    megadepth_train = ConcatDataset(megadepth_train1 + megadepth_train2)
+    mega_ws = mega.weight_scenes(megadepth_train, alpha=0.75)
+    
+    scannet = ScanNetBuilder(data_root="data/scannet")
+    scannet_train = scannet.build_scenes(split="train", ht=h, wt=w, use_horizontal_flip_aug = use_horizontal_flip_aug)
+    scannet_train = ConcatDataset(scannet_train)
+    scannet_ws = scannet.weight_scenes(scannet_train, alpha=0.75)
+
+    # Loss and optimizer
+    depth_loss_scannet = RobustLosses(
+        ce_weight=0.0, 
+        local_dist={1:4, 2:4, 4:8, 8:8},
+        local_largest_scale=8,
+        depth_interpolation_mode=depth_interpolation_mode,
+        alpha = 0.5,
+        c = 1e-4,)
+    # Loss and optimizer
+    depth_loss_mega = RobustLosses(
+        ce_weight=0.01, 
+        local_dist={1:4, 2:4, 4:8, 8:8},
+        local_largest_scale=8,
+        depth_interpolation_mode=depth_interpolation_mode,
+        alpha = 0.5,
+        c = 1e-4,)
+    parameters = [
+        {"params": model.encoder.parameters(), "lr": roma.STEP_SIZE * 5e-6 / 8},
+        {"params": model.decoder.parameters(), "lr": roma.STEP_SIZE * 1e-4 / 8},
+    ]
+    optimizer = torch.optim.AdamW(parameters, weight_decay=0.01)
+    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
+        optimizer, milestones=[(9*N/roma.STEP_SIZE)//10])
+    megadense_benchmark = MegadepthDenseBenchmark("data/megadepth", num_samples = 1000, h=h,w=w)
+    checkpointer = CheckPoint(checkpoint_dir, experiment_name)
+    model, optimizer, lr_scheduler, global_step = checkpointer.load(model, optimizer, lr_scheduler, global_step)
+    roma.GLOBAL_STEP = global_step
+    ddp_model = DDP(model, device_ids=[device_id], find_unused_parameters = False, gradient_as_bucket_view=True)
+    grad_scaler = torch.cuda.amp.GradScaler(growth_interval=1_000_000)
+    grad_clip_norm = 0.01
+    for n in range(roma.GLOBAL_STEP, N, k * roma.STEP_SIZE):
+        mega_sampler = torch.utils.data.WeightedRandomSampler(
+            mega_ws, num_samples = batch_size * k, replacement=False
+        )
+        mega_dataloader = iter(
+            torch.utils.data.DataLoader(
+                megadepth_train,
+                batch_size = batch_size,
+                sampler = mega_sampler,
+                num_workers = 8,
+            )
+        )
+        scannet_ws_sampler = torch.utils.data.WeightedRandomSampler(
+            scannet_ws, num_samples=batch_size * k, replacement=False
+        )
+        scannet_dataloader = iter(
+            torch.utils.data.DataLoader(
+                scannet_train,
+                batch_size=batch_size,
+                sampler=scannet_ws_sampler,
+                num_workers=gpus * 8,
+            )
+        )
+        for n_k in tqdm(range(n, n + 2 * k, 2),disable = roma.RANK > 0):
+            train_k_steps(
+                n_k, 1, mega_dataloader, ddp_model, depth_loss_mega, optimizer, lr_scheduler, grad_scaler, grad_clip_norm = grad_clip_norm, progress_bar=False
+            )
+            train_k_steps(
+                n_k + 1, 1, scannet_dataloader, ddp_model, depth_loss_scannet, optimizer, lr_scheduler, grad_scaler, grad_clip_norm = grad_clip_norm, progress_bar=False
+            )
+        checkpointer.save(model, optimizer, lr_scheduler, roma.GLOBAL_STEP)
+        wandb.log(megadense_benchmark.benchmark(model), step = roma.GLOBAL_STEP)
+
+def test_scannet(model, name, resolution, sample_mode):
+    scannet_benchmark = ScanNetBenchmark("data/scannet")
+    scannet_results = scannet_benchmark.benchmark(model)
+    json.dump(scannet_results, open(f"results/scannet_{name}.json", "w"))
+
+if __name__ == "__main__":
+    import warnings
+    warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
+    warnings.filterwarnings('ignore')#, category=UserWarning)#, message='WARNING batched routines are designed for small sizes.')
+    os.environ["TORCH_CUDNN_V8_API_ENABLED"] = "1" # For BF16 computations
+    os.environ["OMP_NUM_THREADS"] = "16"
+    
+    import roma
+    parser = ArgumentParser()
+    parser.add_argument("--test", action='store_true')
+    parser.add_argument("--debug_mode", action='store_true')
+    parser.add_argument("--dont_log_wandb", action='store_true')
+    parser.add_argument("--train_resolution", default='medium')
+    parser.add_argument("--gpu_batch_size", default=4, type=int)
+    parser.add_argument("--wandb_entity", required = False)
+
+    args, _ = parser.parse_known_args()
+    roma.DEBUG_MODE = args.debug_mode
+    if not args.test:
+        train(args)
+    experiment_name = os.path.splitext(os.path.basename(__file__))[0]
+    checkpoint_dir = "workspace/"
+    checkpoint_name = checkpoint_dir + experiment_name + ".pth"
+    test_resolution = "medium"
+    sample_mode = "threshold_balanced"
+    symmetric = True
+    upsample_preds = False
+    attenuate_cert = True
+
+    model = get_model(pretrained_backbone=False, resolution = test_resolution, sample_mode = sample_mode, upsample_preds = upsample_preds, symmetric=symmetric, name=experiment_name, attenuate_cert = attenuate_cert)
+    model = model.cuda()
+    states = torch.load(checkpoint_name)
+    model.load_state_dict(states["model"])
+    test_scannet(model, experiment_name, resolution = test_resolution, sample_mode = sample_mode)

+ 323 - 0
experiments/roma_outdoor.py

@@ -0,0 +1,323 @@
+import os
+import torch
+from argparse import ArgumentParser
+
+from torch import nn
+from torch.utils.data import ConcatDataset
+import torch.distributed as dist
+from torch.nn.parallel import DistributedDataParallel as DDP
+import json
+import wandb
+
+from roma.benchmarks import MegadepthDenseBenchmark
+from roma.datasets.megadepth import MegadepthBuilder
+from roma.losses.robust_loss import RobustLosses
+from roma.benchmarks import MegaDepthPoseEstimationBenchmark, MegadepthDenseBenchmark, HpatchesHomogBenchmark
+
+from roma.train.train import train_k_steps
+from roma.models.matcher import *
+from roma.models.transformer import Block, TransformerDecoder, MemEffAttention
+from roma.models.encoders import *
+from roma.checkpointing import CheckPoint
+
+resolutions = {"low":(448, 448), "medium":(14*8*5, 14*8*5), "high":(14*8*6, 14*8*6)}
+
+def get_model(pretrained_backbone=True, resolution = "medium", **kwargs):
+    import warnings
+    warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
+    gp_dim = 512
+    feat_dim = 512
+    decoder_dim = gp_dim + feat_dim
+    cls_to_coord_res = 64
+    coordinate_decoder = TransformerDecoder(
+        nn.Sequential(*[Block(decoder_dim, 8, attn_class=MemEffAttention) for _ in range(5)]), 
+        decoder_dim, 
+        cls_to_coord_res**2 + 1,
+        is_classifier=True,
+        amp = True,
+        pos_enc = False,)
+    dw = True
+    hidden_blocks = 8
+    kernel_size = 5
+    displacement_emb = "linear"
+    disable_local_corr_grad = True
+    
+    conv_refiner = nn.ModuleDict(
+        {
+            "16": ConvRefiner(
+                2 * 512+128+(2*7+1)**2,
+                2 * 512+128+(2*7+1)**2,
+                2 + 1,
+                kernel_size=kernel_size,
+                dw=dw,
+                hidden_blocks=hidden_blocks,
+                displacement_emb=displacement_emb,
+                displacement_emb_dim=128,
+                local_corr_radius = 7,
+                corr_in_other = True,
+                amp = True,
+                disable_local_corr_grad = disable_local_corr_grad,
+                bn_momentum = 0.01,
+            ),
+            "8": ConvRefiner(
+                2 * 512+64+(2*3+1)**2,
+                2 * 512+64+(2*3+1)**2,
+                2 + 1,
+                kernel_size=kernel_size,
+                dw=dw,
+                hidden_blocks=hidden_blocks,
+                displacement_emb=displacement_emb,
+                displacement_emb_dim=64,
+                local_corr_radius = 3,
+                corr_in_other = True,
+                amp = True,
+                disable_local_corr_grad = disable_local_corr_grad,
+                bn_momentum = 0.01,
+            ),
+            "4": ConvRefiner(
+                2 * 256+32+(2*2+1)**2,
+                2 * 256+32+(2*2+1)**2,
+                2 + 1,
+                kernel_size=kernel_size,
+                dw=dw,
+                hidden_blocks=hidden_blocks,
+                displacement_emb=displacement_emb,
+                displacement_emb_dim=32,
+                local_corr_radius = 2,
+                corr_in_other = True,
+                amp = True,
+                disable_local_corr_grad = disable_local_corr_grad,
+                bn_momentum = 0.01,
+            ),
+            "2": ConvRefiner(
+                2 * 64+16,
+                128+16,
+                2 + 1,
+                kernel_size=kernel_size,
+                dw=dw,
+                hidden_blocks=hidden_blocks,
+                displacement_emb=displacement_emb,
+                displacement_emb_dim=16,
+                amp = True,
+                disable_local_corr_grad = disable_local_corr_grad,
+                bn_momentum = 0.01,
+            ),
+            "1": ConvRefiner(
+                2 * 9 + 6,
+                24,
+                2 + 1,
+                kernel_size=kernel_size,
+                dw=dw,
+                hidden_blocks = hidden_blocks,
+                displacement_emb = displacement_emb,
+                displacement_emb_dim = 6,
+                amp = True,
+                disable_local_corr_grad = disable_local_corr_grad,
+                bn_momentum = 0.01,
+            ),
+        }
+    )
+    kernel_temperature = 0.2
+    learn_temperature = False
+    no_cov = True
+    kernel = CosKernel
+    only_attention = False
+    basis = "fourier"
+    gp16 = GP(
+        kernel,
+        T=kernel_temperature,
+        learn_temperature=learn_temperature,
+        only_attention=only_attention,
+        gp_dim=gp_dim,
+        basis=basis,
+        no_cov=no_cov,
+    )
+    gps = nn.ModuleDict({"16": gp16})
+    proj16 = nn.Sequential(nn.Conv2d(1024, 512, 1, 1), nn.BatchNorm2d(512))
+    proj8 = nn.Sequential(nn.Conv2d(512, 512, 1, 1), nn.BatchNorm2d(512))
+    proj4 = nn.Sequential(nn.Conv2d(256, 256, 1, 1), nn.BatchNorm2d(256))
+    proj2 = nn.Sequential(nn.Conv2d(128, 64, 1, 1), nn.BatchNorm2d(64))
+    proj1 = nn.Sequential(nn.Conv2d(64, 9, 1, 1), nn.BatchNorm2d(9))
+    proj = nn.ModuleDict({
+        "16": proj16,
+        "8": proj8,
+        "4": proj4,
+        "2": proj2,
+        "1": proj1,
+        })
+    displacement_dropout_p = 0.0
+    gm_warp_dropout_p = 0.0
+    decoder = Decoder(coordinate_decoder, 
+                      gps, 
+                      proj, 
+                      conv_refiner, 
+                      detach=True, 
+                      scales=["16", "8", "4", "2", "1"], 
+                      displacement_dropout_p = displacement_dropout_p,
+                      gm_warp_dropout_p = gm_warp_dropout_p)
+    h,w = resolutions[resolution]
+    encoder = CNNandDinov2(
+        cnn_kwargs = dict(
+            pretrained=pretrained_backbone,
+            amp = True),
+        amp = True,
+        use_vgg = True,
+    )
+    matcher = RegressionMatcher(encoder, decoder, h=h, w=w,**kwargs)
+    return matcher
+
+def train(args):
+    dist.init_process_group('nccl')
+    #torch._dynamo.config.verbose=True
+    gpus = int(os.environ['WORLD_SIZE'])
+    # create model and move it to GPU with id rank
+    rank = dist.get_rank()
+    print(f"Start running DDP on rank {rank}")
+    device_id = rank % torch.cuda.device_count()
+    roma.LOCAL_RANK = device_id
+    torch.cuda.set_device(device_id)
+    
+    resolution = args.train_resolution
+    wandb_log = not args.dont_log_wandb
+    experiment_name = os.path.splitext(os.path.basename(__file__))[0]
+    wandb_mode = "online" if wandb_log and rank == 0 else "disabled"
+    wandb.init(project="roma", entity=args.wandb_entity, name=experiment_name, reinit=False, mode = wandb_mode)
+    checkpoint_dir = "workspace/checkpoints/"
+    h,w = resolutions[resolution]
+    model = get_model(pretrained_backbone=True, resolution=resolution, attenuate_cert = False).to(device_id)
+    # Num steps
+    global_step = 0
+    batch_size = args.gpu_batch_size
+    step_size = gpus*batch_size
+    roma.STEP_SIZE = step_size
+    
+    N = (32 * 250000)  # 250k steps of batch size 32
+    # checkpoint every
+    k = 25000 // roma.STEP_SIZE
+
+    # Data
+    mega = MegadepthBuilder(data_root="data/megadepth", loftr_ignore=True, imc21_ignore = True)
+    use_horizontal_flip_aug = True
+    rot_prob = 0
+    depth_interpolation_mode = "bilinear"
+    megadepth_train1 = mega.build_scenes(
+        split="train_loftr", min_overlap=0.01, shake_t=32, use_horizontal_flip_aug = use_horizontal_flip_aug, rot_prob = rot_prob,
+        ht=h,wt=w,
+    )
+    megadepth_train2 = mega.build_scenes(
+        split="train_loftr", min_overlap=0.35, shake_t=32, use_horizontal_flip_aug = use_horizontal_flip_aug, rot_prob = rot_prob,
+        ht=h,wt=w,
+    )
+    megadepth_train = ConcatDataset(megadepth_train1 + megadepth_train2)
+    mega_ws = mega.weight_scenes(megadepth_train, alpha=0.75)
+    # Loss and optimizer
+    depth_loss = RobustLosses(
+        ce_weight=0.01, 
+        local_dist={1:4, 2:4, 4:8, 8:8},
+        local_largest_scale=8,
+        depth_interpolation_mode=depth_interpolation_mode,
+        alpha = 0.5,
+        c = 1e-4,)
+    parameters = [
+        {"params": model.encoder.parameters(), "lr": roma.STEP_SIZE * 5e-6 / 8},
+        {"params": model.decoder.parameters(), "lr": roma.STEP_SIZE * 1e-4 / 8},
+    ]
+    optimizer = torch.optim.AdamW(parameters, weight_decay=0.01)
+    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
+        optimizer, milestones=[(9*N/roma.STEP_SIZE)//10])
+    megadense_benchmark = MegadepthDenseBenchmark("data/megadepth", num_samples = 1000, h=h,w=w)
+    checkpointer = CheckPoint(checkpoint_dir, experiment_name)
+    model, optimizer, lr_scheduler, global_step = checkpointer.load(model, optimizer, lr_scheduler, global_step)
+    roma.GLOBAL_STEP = global_step
+    ddp_model = DDP(model, device_ids=[device_id], find_unused_parameters = False, gradient_as_bucket_view=True)
+    grad_scaler = torch.cuda.amp.GradScaler(growth_interval=1_000_000)
+    grad_clip_norm = 0.01
+    for n in range(roma.GLOBAL_STEP, N, k * roma.STEP_SIZE):
+        mega_sampler = torch.utils.data.WeightedRandomSampler(
+            mega_ws, num_samples = batch_size * k, replacement=False
+        )
+        mega_dataloader = iter(
+            torch.utils.data.DataLoader(
+                megadepth_train,
+                batch_size = batch_size,
+                sampler = mega_sampler,
+                num_workers = 8,
+            )
+        )
+        train_k_steps(
+            n, k, mega_dataloader, ddp_model, depth_loss, optimizer, lr_scheduler, grad_scaler, grad_clip_norm = grad_clip_norm,
+        )
+        checkpointer.save(model, optimizer, lr_scheduler, roma.GLOBAL_STEP)
+        wandb.log(megadense_benchmark.benchmark(model), step = roma.GLOBAL_STEP)
+
+def test_mega_8_scenes(model, name, resolution, sample_mode):
+    mega_8_scenes_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth",
+                                                scene_names=['mega_8_scenes_0019_0.1_0.3.npz',
+                                                    'mega_8_scenes_0025_0.1_0.3.npz',
+                                                    'mega_8_scenes_0021_0.1_0.3.npz',
+                                                    'mega_8_scenes_0008_0.1_0.3.npz',
+                                                    'mega_8_scenes_0032_0.1_0.3.npz',
+                                                    'mega_8_scenes_1589_0.1_0.3.npz',
+                                                    'mega_8_scenes_0063_0.1_0.3.npz',
+                                                    'mega_8_scenes_0024_0.1_0.3.npz',
+                                                    'mega_8_scenes_0019_0.3_0.5.npz',
+                                                    'mega_8_scenes_0025_0.3_0.5.npz',
+                                                    'mega_8_scenes_0021_0.3_0.5.npz',
+                                                    'mega_8_scenes_0008_0.3_0.5.npz',
+                                                    'mega_8_scenes_0032_0.3_0.5.npz',
+                                                    'mega_8_scenes_1589_0.3_0.5.npz',
+                                                    'mega_8_scenes_0063_0.3_0.5.npz',
+                                                    'mega_8_scenes_0024_0.3_0.5.npz'])
+    mega_8_scenes_results = mega_8_scenes_benchmark.benchmark(model, model_name=name, scale_intrinsics = False)
+    print(mega_8_scenes_results)
+    json.dump(mega_8_scenes_results, open(f"results/mega_8_scenes_{name}.json", "w"))
+
+def test_mega1500(model, name, resolution, sample_mode):
+    mega1500_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth")
+    mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
+    json.dump(mega1500_results, open(f"results/mega1500_{name}.json", "w"))
+
+def test_mega_dense(model, name, resolution, sample_mode):
+    megadense_benchmark = MegadepthDenseBenchmark("data/megadepth", num_samples = 1000)
+    megadense_results = megadense_benchmark.benchmark(model)
+    json.dump(megadense_results, open(f"results/mega_dense_{name}.json", "w"))
+    
+def test_hpatches(model, name, resolution, sample_mode):
+    hpatches_benchmark = HpatchesHomogBenchmark("data/hpatches")
+    hpatches_results = hpatches_benchmark.benchmark(model)
+    json.dump(hpatches_results, open(f"results/hpatches_{name}.json", "w"))
+
+
+if __name__ == "__main__":
+    os.environ["TORCH_CUDNN_V8_API_ENABLED"] = "1" # For BF16 computations
+    os.environ["OMP_NUM_THREADS"] = "16"
+    torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
+    torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
+    import roma
+    parser = ArgumentParser()
+    parser.add_argument("--only_test", action='store_true')
+    parser.add_argument("--debug_mode", action='store_true')
+    parser.add_argument("--dont_log_wandb", action='store_true')
+    parser.add_argument("--train_resolution", default='medium')
+    parser.add_argument("--gpu_batch_size", default=4, type=int)
+    parser.add_argument("--wandb_entity", required = False)
+
+    args, _ = parser.parse_known_args()
+    roma.DEBUG_MODE = args.debug_mode
+    if not args.only_test:
+        train(args)
+    experiment_name = os.path.splitext(os.path.basename(__file__))[0]
+    checkpoint_dir = "workspace/checkpoints/"
+    checkpoint_name = checkpoint_dir + experiment_name + ".pth"
+    
+    test_resolution = "high"
+    sample_mode = "threshold_balanced"
+    symmetric = True
+    upsample_preds = True
+    attenuate_cert = True
+
+    model = get_model(pretrained_backbone=False, resolution = test_resolution, sample_mode = sample_mode, upsample_preds = upsample_preds, symmetric=symmetric, name=experiment_name, attenuate_cert = attenuate_cert)
+    model = model.cuda()
+    weights = torch.load(checkpoint_name)
+    model.load_state_dict(weights)
+    test_mega1500(model, experiment_name, resolution = test_resolution, sample_mode = sample_mode)

+ 13 - 0
requirements.txt

@@ -0,0 +1,13 @@
+torch
+einops
+torchvision
+opencv-python
+kornia
+albumentations
+loguru
+tqdm
+matplotlib
+h5py
+wandb
+timm
+xformers # Optional, used for memefficient attention

+ 8 - 0
roma/__init__.py

@@ -0,0 +1,8 @@
+import os
+from .models import roma_outdoor, roma_indoor
+
+DEBUG_MODE = False
+RANK = int(os.environ.get('RANK', default = 0))
+GLOBAL_STEP = 0
+STEP_SIZE = 1
+LOCAL_RANK = -1

+ 4 - 0
roma/benchmarks/__init__.py

@@ -0,0 +1,4 @@
+from .hpatches_sequences_homog_benchmark import HpatchesHomogBenchmark
+from .scannet_benchmark import ScanNetBenchmark
+from .megadepth_pose_estimation_benchmark import MegaDepthPoseEstimationBenchmark
+from .megadepth_dense_benchmark import MegadepthDenseBenchmark

+ 113 - 0
roma/benchmarks/hpatches_sequences_homog_benchmark.py

@@ -0,0 +1,113 @@
+from PIL import Image
+import numpy as np
+
+import os
+
+from tqdm import tqdm
+from roma.utils import pose_auc
+import cv2
+
+
+class HpatchesHomogBenchmark:
+    """Hpatches grid goes from [0,n-1] instead of [0.5,n-0.5]"""
+
+    def __init__(self, dataset_path) -> None:
+        seqs_dir = "hpatches-sequences-release"
+        self.seqs_path = os.path.join(dataset_path, seqs_dir)
+        self.seq_names = sorted(os.listdir(self.seqs_path))
+        # Ignore seqs is same as LoFTR.
+        self.ignore_seqs = set(
+            [
+                "i_contruction",
+                "i_crownnight",
+                "i_dc",
+                "i_pencils",
+                "i_whitebuilding",
+                "v_artisans",
+                "v_astronautis",
+                "v_talent",
+            ]
+        )
+
+    def convert_coordinates(self, im_A_coords, im_A_to_im_B, wq, hq, wsup, hsup):
+        offset = 0.5  # Hpatches assumes that the center of the top-left pixel is at [0,0] (I think)
+        im_A_coords = (
+            np.stack(
+                (
+                    wq * (im_A_coords[..., 0] + 1) / 2,
+                    hq * (im_A_coords[..., 1] + 1) / 2,
+                ),
+                axis=-1,
+            )
+            - offset
+        )
+        im_A_to_im_B = (
+            np.stack(
+                (
+                    wsup * (im_A_to_im_B[..., 0] + 1) / 2,
+                    hsup * (im_A_to_im_B[..., 1] + 1) / 2,
+                ),
+                axis=-1,
+            )
+            - offset
+        )
+        return im_A_coords, im_A_to_im_B
+
+    def benchmark(self, model, model_name = None):
+        n_matches = []
+        homog_dists = []
+        for seq_idx, seq_name in tqdm(
+            enumerate(self.seq_names), total=len(self.seq_names)
+        ):
+            im_A_path = os.path.join(self.seqs_path, seq_name, "1.ppm")
+            im_A = Image.open(im_A_path)
+            w1, h1 = im_A.size
+            for im_idx in range(2, 7):
+                im_B_path = os.path.join(self.seqs_path, seq_name, f"{im_idx}.ppm")
+                im_B = Image.open(im_B_path)
+                w2, h2 = im_B.size
+                H = np.loadtxt(
+                    os.path.join(self.seqs_path, seq_name, "H_1_" + str(im_idx))
+                )
+                dense_matches, dense_certainty = model.match(
+                    im_A_path, im_B_path
+                )
+                good_matches, _ = model.sample(dense_matches, dense_certainty, 5000)
+                pos_a, pos_b = self.convert_coordinates(
+                    good_matches[:, :2], good_matches[:, 2:], w1, h1, w2, h2
+                )
+                try:
+                    H_pred, inliers = cv2.findHomography(
+                        pos_a,
+                        pos_b,
+                        method = cv2.RANSAC,
+                        confidence = 0.99999,
+                        ransacReprojThreshold = 3 * min(w2, h2) / 480,
+                    )
+                except:
+                    H_pred = None
+                if H_pred is None:
+                    H_pred = np.zeros((3, 3))
+                    H_pred[2, 2] = 1.0
+                corners = np.array(
+                    [[0, 0, 1], [0, h1 - 1, 1], [w1 - 1, 0, 1], [w1 - 1, h1 - 1, 1]]
+                )
+                real_warped_corners = np.dot(corners, np.transpose(H))
+                real_warped_corners = (
+                    real_warped_corners[:, :2] / real_warped_corners[:, 2:]
+                )
+                warped_corners = np.dot(corners, np.transpose(H_pred))
+                warped_corners = warped_corners[:, :2] / warped_corners[:, 2:]
+                mean_dist = np.mean(
+                    np.linalg.norm(real_warped_corners - warped_corners, axis=1)
+                ) / (min(w2, h2) / 480.0)
+                homog_dists.append(mean_dist)
+
+        n_matches = np.array(n_matches)
+        thresholds = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+        auc = pose_auc(np.array(homog_dists), thresholds)
+        return {
+            "hpatches_homog_auc_3": auc[2],
+            "hpatches_homog_auc_5": auc[4],
+            "hpatches_homog_auc_10": auc[9],
+        }

+ 106 - 0
roma/benchmarks/megadepth_dense_benchmark.py

@@ -0,0 +1,106 @@
+import torch
+import numpy as np
+import tqdm
+from roma.datasets import MegadepthBuilder
+from roma.utils import warp_kpts
+from torch.utils.data import ConcatDataset
+import roma
+
+class MegadepthDenseBenchmark:
+    def __init__(self, data_root="data/megadepth", h = 384, w = 512, num_samples = 2000) -> None:
+        mega = MegadepthBuilder(data_root=data_root)
+        self.dataset = ConcatDataset(
+            mega.build_scenes(split="test_loftr", ht=h, wt=w)
+        )  # fixed resolution of 384,512
+        self.num_samples = num_samples
+
+    def geometric_dist(self, depth1, depth2, T_1to2, K1, K2, dense_matches):
+        b, h1, w1, d = dense_matches.shape
+        with torch.no_grad():
+            x1 = dense_matches[..., :2].reshape(b, h1 * w1, 2)
+            mask, x2 = warp_kpts(
+                x1.double(),
+                depth1.double(),
+                depth2.double(),
+                T_1to2.double(),
+                K1.double(),
+                K2.double(),
+            )
+            x2 = torch.stack(
+                (w1 * (x2[..., 0] + 1) / 2, h1 * (x2[..., 1] + 1) / 2), dim=-1
+            )
+            prob = mask.float().reshape(b, h1, w1)
+        x2_hat = dense_matches[..., 2:]
+        x2_hat = torch.stack(
+            (w1 * (x2_hat[..., 0] + 1) / 2, h1 * (x2_hat[..., 1] + 1) / 2), dim=-1
+        )
+        gd = (x2_hat - x2.reshape(b, h1, w1, 2)).norm(dim=-1)
+        gd = gd[prob == 1]
+        pck_1 = (gd < 1.0).float().mean()
+        pck_3 = (gd < 3.0).float().mean()
+        pck_5 = (gd < 5.0).float().mean()
+        return gd, pck_1, pck_3, pck_5, prob
+
+    def benchmark(self, model, batch_size=8):
+        model.train(False)
+        with torch.no_grad():
+            gd_tot = 0.0
+            pck_1_tot = 0.0
+            pck_3_tot = 0.0
+            pck_5_tot = 0.0
+            sampler = torch.utils.data.WeightedRandomSampler(
+                torch.ones(len(self.dataset)), replacement=False, num_samples=self.num_samples
+            )
+            B = batch_size
+            dataloader = torch.utils.data.DataLoader(
+                self.dataset, batch_size=B, num_workers=batch_size, sampler=sampler
+            )
+            for idx, data in tqdm.tqdm(enumerate(dataloader), disable = roma.RANK > 0):
+                im_A, im_B, depth1, depth2, T_1to2, K1, K2 = (
+                    data["im_A"],
+                    data["im_B"],
+                    data["im_A_depth"].cuda(),
+                    data["im_B_depth"].cuda(),
+                    data["T_1to2"].cuda(),
+                    data["K1"].cuda(),
+                    data["K2"].cuda(),
+                )
+                matches, certainty = model.match(im_A, im_B, batched=True)
+                gd, pck_1, pck_3, pck_5, prob = self.geometric_dist(
+                    depth1, depth2, T_1to2, K1, K2, matches
+                )
+                if roma.DEBUG_MODE:
+                    from roma.utils.utils import tensor_to_pil
+                    import torch.nn.functional as F
+                    path = "vis"
+                    H, W = model.get_output_resolution()
+                    white_im = torch.ones((B,1,H,W),device="cuda")
+                    im_B_transfer_rgb = F.grid_sample(
+                        im_B.cuda(), matches[:,:,:W, 2:], mode="bilinear", align_corners=False
+                    )
+                    warp_im = im_B_transfer_rgb
+                    c_b = certainty[:,None]#(certainty*0.9 + 0.1*torch.ones_like(certainty))[:,None]
+                    vis_im = c_b * warp_im + (1 - c_b) * white_im
+                    for b in range(B):
+                        import os
+                        os.makedirs(f"{path}/{model.name}/{idx}_{b}_{H}_{W}",exist_ok=True)
+                        tensor_to_pil(vis_im[b], unnormalize=True).save(
+                            f"{path}/{model.name}/{idx}_{b}_{H}_{W}/warp.jpg")
+                        tensor_to_pil(im_A[b].cuda(), unnormalize=True).save(
+                            f"{path}/{model.name}/{idx}_{b}_{H}_{W}/im_A.jpg")
+                        tensor_to_pil(im_B[b].cuda(), unnormalize=True).save(
+                            f"{path}/{model.name}/{idx}_{b}_{H}_{W}/im_B.jpg")
+
+
+                gd_tot, pck_1_tot, pck_3_tot, pck_5_tot = (
+                    gd_tot + gd.mean(),
+                    pck_1_tot + pck_1,
+                    pck_3_tot + pck_3,
+                    pck_5_tot + pck_5,
+                )
+        return {
+            "epe": gd_tot.item() / len(dataloader),
+            "mega_pck_1": pck_1_tot.item() / len(dataloader),
+            "mega_pck_3": pck_3_tot.item() / len(dataloader),
+            "mega_pck_5": pck_5_tot.item() / len(dataloader),
+        }

+ 140 - 0
roma/benchmarks/megadepth_pose_estimation_benchmark.py

@@ -0,0 +1,140 @@
+import numpy as np
+import torch
+from roma.utils import *
+from PIL import Image
+from tqdm import tqdm
+import torch.nn.functional as F
+import roma
+import kornia.geometry.epipolar as kepi
+
+class MegaDepthPoseEstimationBenchmark:
+    def __init__(self, data_root="data/megadepth", scene_names = None) -> None:
+        if scene_names is None:
+            self.scene_names = [
+                "0015_0.1_0.3.npz",
+                "0015_0.3_0.5.npz",
+                "0022_0.1_0.3.npz",
+                "0022_0.3_0.5.npz",
+                "0022_0.5_0.7.npz",
+            ]
+        else:
+            self.scene_names = scene_names
+        self.scenes = [
+            np.load(f"{data_root}/{scene}", allow_pickle=True)
+            for scene in self.scene_names
+        ]
+        self.data_root = data_root
+
+    def benchmark(self, model, model_name = None, resolution = None, scale_intrinsics = True, calibrated = True):
+        H,W = model.get_output_resolution()
+        with torch.no_grad():
+            data_root = self.data_root
+            tot_e_t, tot_e_R, tot_e_pose = [], [], []
+            thresholds = [5, 10, 20]
+            for scene_ind in range(len(self.scenes)):
+                import os
+                scene_name = os.path.splitext(self.scene_names[scene_ind])[0]
+                scene = self.scenes[scene_ind]
+                pairs = scene["pair_infos"]
+                intrinsics = scene["intrinsics"]
+                poses = scene["poses"]
+                im_paths = scene["image_paths"]
+                pair_inds = range(len(pairs))
+                for pairind in tqdm(pair_inds):
+                    idx1, idx2 = pairs[pairind][0]
+                    K1 = intrinsics[idx1].copy()
+                    T1 = poses[idx1].copy()
+                    R1, t1 = T1[:3, :3], T1[:3, 3]
+                    K2 = intrinsics[idx2].copy()
+                    T2 = poses[idx2].copy()
+                    R2, t2 = T2[:3, :3], T2[:3, 3]
+                    R, t = compute_relative_pose(R1, t1, R2, t2)
+                    T1_to_2 = np.concatenate((R,t[:,None]), axis=-1)
+                    im_A_path = f"{data_root}/{im_paths[idx1]}"
+                    im_B_path = f"{data_root}/{im_paths[idx2]}"
+                    dense_matches, dense_certainty = model.match(
+                        im_A_path, im_B_path, K1.copy(), K2.copy(), T1_to_2.copy()
+                    )
+                    sparse_matches,_ = model.sample(
+                        dense_matches, dense_certainty, 5000
+                    )
+                    
+                    im_A = Image.open(im_A_path)
+                    w1, h1 = im_A.size
+                    im_B = Image.open(im_B_path)
+                    w2, h2 = im_B.size
+
+                    if scale_intrinsics:
+                        scale1 = 1200 / max(w1, h1)
+                        scale2 = 1200 / max(w2, h2)
+                        w1, h1 = scale1 * w1, scale1 * h1
+                        w2, h2 = scale2 * w2, scale2 * h2
+                        K1, K2 = K1.copy(), K2.copy()
+                        K1[:2] = K1[:2] * scale1
+                        K2[:2] = K2[:2] * scale2
+
+                    kpts1 = sparse_matches[:, :2]
+                    kpts1 = (
+                        np.stack(
+                            (
+                                w1 * (kpts1[:, 0] + 1) / 2,
+                                h1 * (kpts1[:, 1] + 1) / 2,
+                            ),
+                            axis=-1,
+                        )
+                    )
+                    kpts2 = sparse_matches[:, 2:]
+                    kpts2 = (
+                        np.stack(
+                            (
+                                w2 * (kpts2[:, 0] + 1) / 2,
+                                h2 * (kpts2[:, 1] + 1) / 2,
+                            ),
+                            axis=-1,
+                        )
+                    )
+
+                    for _ in range(5):
+                        shuffling = np.random.permutation(np.arange(len(kpts1)))
+                        kpts1 = kpts1[shuffling]
+                        kpts2 = kpts2[shuffling]
+                        try:
+                            threshold = 0.5 
+                            if calibrated:
+                                norm_threshold = threshold / (np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2])))
+                                R_est, t_est, mask = estimate_pose(
+                                    kpts1,
+                                    kpts2,
+                                    K1,
+                                    K2,
+                                    norm_threshold,
+                                    conf=0.99999,
+                                )
+                            T1_to_2_est = np.concatenate((R_est, t_est), axis=-1)  #
+                            e_t, e_R = compute_pose_error(T1_to_2_est, R, t)
+                            e_pose = max(e_t, e_R)
+                        except Exception as e:
+                            print(repr(e))
+                            e_t, e_R = 90, 90
+                            e_pose = max(e_t, e_R)
+                        tot_e_t.append(e_t)
+                        tot_e_R.append(e_R)
+                        tot_e_pose.append(e_pose)
+            tot_e_pose = np.array(tot_e_pose)
+            auc = pose_auc(tot_e_pose, thresholds)
+            acc_5 = (tot_e_pose < 5).mean()
+            acc_10 = (tot_e_pose < 10).mean()
+            acc_15 = (tot_e_pose < 15).mean()
+            acc_20 = (tot_e_pose < 20).mean()
+            map_5 = acc_5
+            map_10 = np.mean([acc_5, acc_10])
+            map_20 = np.mean([acc_5, acc_10, acc_15, acc_20])
+            print(f"{model_name} auc: {auc}")
+            return {
+                "auc_5": auc[0],
+                "auc_10": auc[1],
+                "auc_20": auc[2],
+                "map_5": map_5,
+                "map_10": map_10,
+                "map_20": map_20,
+            }

+ 143 - 0
roma/benchmarks/scannet_benchmark.py

@@ -0,0 +1,143 @@
+import os.path as osp
+import numpy as np
+import torch
+from roma.utils import *
+from PIL import Image
+from tqdm import tqdm
+
+
+class ScanNetBenchmark:
+    def __init__(self, data_root="data/scannet") -> None:
+        self.data_root = data_root
+
+    def benchmark(self, model, model_name = None):
+        model.train(False)
+        with torch.no_grad():
+            data_root = self.data_root
+            tmp = np.load(osp.join(data_root, "test.npz"))
+            pairs, rel_pose = tmp["name"], tmp["rel_pose"]
+            tot_e_t, tot_e_R, tot_e_pose = [], [], []
+            pair_inds = np.random.choice(
+                range(len(pairs)), size=len(pairs), replace=False
+            )
+            for pairind in tqdm(pair_inds, smoothing=0.9):
+                scene = pairs[pairind]
+                scene_name = f"scene0{scene[0]}_00"
+                im_A_path = osp.join(
+                        self.data_root,
+                        "scans_test",
+                        scene_name,
+                        "color",
+                        f"{scene[2]}.jpg",
+                    )
+                im_A = Image.open(im_A_path)
+                im_B_path = osp.join(
+                        self.data_root,
+                        "scans_test",
+                        scene_name,
+                        "color",
+                        f"{scene[3]}.jpg",
+                    )
+                im_B = Image.open(im_B_path)
+                T_gt = rel_pose[pairind].reshape(3, 4)
+                R, t = T_gt[:3, :3], T_gt[:3, 3]
+                K = np.stack(
+                    [
+                        np.array([float(i) for i in r.split()])
+                        for r in open(
+                            osp.join(
+                                self.data_root,
+                                "scans_test",
+                                scene_name,
+                                "intrinsic",
+                                "intrinsic_color.txt",
+                            ),
+                            "r",
+                        )
+                        .read()
+                        .split("\n")
+                        if r
+                    ]
+                )
+                w1, h1 = im_A.size
+                w2, h2 = im_B.size
+                K1 = K.copy()
+                K2 = K.copy()
+                dense_matches, dense_certainty = model.match(im_A_path, im_B_path)
+                sparse_matches, sparse_certainty = model.sample(
+                    dense_matches, dense_certainty, 5000
+                )
+                scale1 = 480 / min(w1, h1)
+                scale2 = 480 / min(w2, h2)
+                w1, h1 = scale1 * w1, scale1 * h1
+                w2, h2 = scale2 * w2, scale2 * h2
+                K1 = K1 * scale1
+                K2 = K2 * scale2
+
+                offset = 0.5
+                kpts1 = sparse_matches[:, :2]
+                kpts1 = (
+                    np.stack(
+                        (
+                            w1 * (kpts1[:, 0] + 1) / 2 - offset,
+                            h1 * (kpts1[:, 1] + 1) / 2 - offset,
+                        ),
+                        axis=-1,
+                    )
+                )
+                kpts2 = sparse_matches[:, 2:]
+                kpts2 = (
+                    np.stack(
+                        (
+                            w2 * (kpts2[:, 0] + 1) / 2 - offset,
+                            h2 * (kpts2[:, 1] + 1) / 2 - offset,
+                        ),
+                        axis=-1,
+                    )
+                )
+                for _ in range(5):
+                    shuffling = np.random.permutation(np.arange(len(kpts1)))
+                    kpts1 = kpts1[shuffling]
+                    kpts2 = kpts2[shuffling]
+                    try:
+                        norm_threshold = 0.5 / (
+                        np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2])))
+                        R_est, t_est, mask = estimate_pose(
+                            kpts1,
+                            kpts2,
+                            K1,
+                            K2,
+                            norm_threshold,
+                            conf=0.99999,
+                        )
+                        T1_to_2_est = np.concatenate((R_est, t_est), axis=-1)  #
+                        e_t, e_R = compute_pose_error(T1_to_2_est, R, t)
+                        e_pose = max(e_t, e_R)
+                    except Exception as e:
+                        print(repr(e))
+                        e_t, e_R = 90, 90
+                        e_pose = max(e_t, e_R)
+                    tot_e_t.append(e_t)
+                    tot_e_R.append(e_R)
+                    tot_e_pose.append(e_pose)
+                tot_e_t.append(e_t)
+                tot_e_R.append(e_R)
+                tot_e_pose.append(e_pose)
+            tot_e_pose = np.array(tot_e_pose)
+            thresholds = [5, 10, 20]
+            auc = pose_auc(tot_e_pose, thresholds)
+            acc_5 = (tot_e_pose < 5).mean()
+            acc_10 = (tot_e_pose < 10).mean()
+            acc_15 = (tot_e_pose < 15).mean()
+            acc_20 = (tot_e_pose < 20).mean()
+            map_5 = acc_5
+            map_10 = np.mean([acc_5, acc_10])
+            map_20 = np.mean([acc_5, acc_10, acc_15, acc_20])
+            return {
+                "auc_5": auc[0],
+                "auc_10": auc[1],
+                "auc_20": auc[2],
+                "map_5": map_5,
+                "map_10": map_10,
+                "map_20": map_20,
+            }

+ 1 - 0
roma/checkpointing/__init__.py

@@ -0,0 +1 @@
+from .checkpoint import CheckPoint

+ 60 - 0
roma/checkpointing/checkpoint.py

@@ -0,0 +1,60 @@
+import os
+import torch
+from torch.nn.parallel.data_parallel import DataParallel
+from torch.nn.parallel.distributed import DistributedDataParallel
+from loguru import logger
+import gc
+
+import roma
+
+class CheckPoint:
+    def __init__(self, dir=None, name="tmp"):
+        self.name = name
+        self.dir = dir
+        os.makedirs(self.dir, exist_ok=True)
+
+    def save(
+        self,
+        model,
+        optimizer,
+        lr_scheduler,
+        n,
+        ):
+        if roma.RANK == 0:
+            assert model is not None
+            if isinstance(model, (DataParallel, DistributedDataParallel)):
+                model = model.module
+            states = {
+                "model": model.state_dict(),
+                "n": n,
+                "optimizer": optimizer.state_dict(),
+                "lr_scheduler": lr_scheduler.state_dict(),
+            }
+            torch.save(states, self.dir + self.name + f"_latest.pth")
+            logger.info(f"Saved states {list(states.keys())}, at step {n}")
+    
+    def load(
+        self,
+        model,
+        optimizer,
+        lr_scheduler,
+        n,
+        ):
+        if os.path.exists(self.dir + self.name + f"_latest.pth") and roma.RANK == 0:
+            states = torch.load(self.dir + self.name + f"_latest.pth")
+            if "model" in states:
+                model.load_state_dict(states["model"])
+            if "n" in states:
+                n = states["n"] if states["n"] else n
+            if "optimizer" in states:
+                try:
+                    optimizer.load_state_dict(states["optimizer"])
+                except Exception as e:
+                    print(f"Failed to load states for optimizer, with error {e}")
+            if "lr_scheduler" in states:
+                lr_scheduler.load_state_dict(states["lr_scheduler"])
+            print(f"Loaded states {list(states.keys())}, at step {n}")
+            del states
+            gc.collect()
+            torch.cuda.empty_cache()
+        return model, optimizer, lr_scheduler, n

+ 2 - 0
roma/datasets/__init__.py

@@ -0,0 +1,2 @@
+from .megadepth import MegadepthBuilder
+from .scannet import ScanNetBuilder

+ 230 - 0
roma/datasets/megadepth.py

@@ -0,0 +1,230 @@
+import os
+from PIL import Image
+import h5py
+import numpy as np
+import torch
+import torchvision.transforms.functional as tvf
+import kornia.augmentation as K
+from roma.utils import get_depth_tuple_transform_ops, get_tuple_transform_ops
+import roma
+from roma.utils import *
+import math
+
+class MegadepthScene:
+    def __init__(
+        self,
+        data_root,
+        scene_info,
+        ht=384,
+        wt=512,
+        min_overlap=0.0,
+        max_overlap=1.0,
+        shake_t=0,
+        rot_prob=0.0,
+        normalize=True,
+        max_num_pairs = 100_000,
+        scene_name = None,
+        use_horizontal_flip_aug = False,
+        use_single_horizontal_flip_aug = False,
+        colorjiggle_params = None,
+        random_eraser = None,
+        use_randaug = False,
+        randaug_params = None,
+        randomize_size = False,
+    ) -> None:
+        self.data_root = data_root
+        self.scene_name = os.path.splitext(scene_name)[0]+f"_{min_overlap}_{max_overlap}"
+        self.image_paths = scene_info["image_paths"]
+        self.depth_paths = scene_info["depth_paths"]
+        self.intrinsics = scene_info["intrinsics"]
+        self.poses = scene_info["poses"]
+        self.pairs = scene_info["pairs"]
+        self.overlaps = scene_info["overlaps"]
+        threshold = (self.overlaps > min_overlap) & (self.overlaps < max_overlap)
+        self.pairs = self.pairs[threshold]
+        self.overlaps = self.overlaps[threshold]
+        if len(self.pairs) > max_num_pairs:
+            pairinds = np.random.choice(
+                np.arange(0, len(self.pairs)), max_num_pairs, replace=False
+            )
+            self.pairs = self.pairs[pairinds]
+            self.overlaps = self.overlaps[pairinds]
+        if randomize_size:
+            area = ht * wt
+            s = int(16 * (math.sqrt(area)//16))
+            sizes = ((ht,wt), (s,s), (wt,ht))
+            choice = roma.RANK % 3
+            ht, wt = sizes[choice] 
+        # counts, bins = np.histogram(self.overlaps,20)
+        # print(counts)
+        self.im_transform_ops = get_tuple_transform_ops(
+            resize=(ht, wt), normalize=normalize, colorjiggle_params = colorjiggle_params,
+        )
+        self.depth_transform_ops = get_depth_tuple_transform_ops(
+                resize=(ht, wt)
+            )
+        self.wt, self.ht = wt, ht
+        self.shake_t = shake_t
+        self.random_eraser = random_eraser
+        if use_horizontal_flip_aug and use_single_horizontal_flip_aug:
+            raise ValueError("Can't both flip both images and only flip one")
+        self.use_horizontal_flip_aug = use_horizontal_flip_aug
+        self.use_single_horizontal_flip_aug = use_single_horizontal_flip_aug
+        self.use_randaug = use_randaug
+
+    def load_im(self, im_path):
+        im = Image.open(im_path)
+        return im
+    
+    def horizontal_flip(self, im_A, im_B, depth_A, depth_B,  K_A, K_B):
+        im_A = im_A.flip(-1)
+        im_B = im_B.flip(-1)
+        depth_A, depth_B = depth_A.flip(-1), depth_B.flip(-1) 
+        flip_mat = torch.tensor([[-1, 0, self.wt],[0,1,0],[0,0,1.]]).to(K_A.device)
+        K_A = flip_mat@K_A  
+        K_B = flip_mat@K_B  
+        
+        return im_A, im_B, depth_A, depth_B, K_A, K_B
+    
+    def load_depth(self, depth_ref, crop=None):
+        depth = np.array(h5py.File(depth_ref, "r")["depth"])
+        return torch.from_numpy(depth)
+
+    def __len__(self):
+        return len(self.pairs)
+
+    def scale_intrinsic(self, K, wi, hi):
+        sx, sy = self.wt / wi, self.ht / hi
+        sK = torch.tensor([[sx, 0, 0], [0, sy, 0], [0, 0, 1]])
+        return sK @ K
+
+    def rand_shake(self, *things):
+        t = np.random.choice(range(-self.shake_t, self.shake_t + 1), size=2)
+        return [
+            tvf.affine(thing, angle=0.0, translate=list(t), scale=1.0, shear=[0.0, 0.0])
+            for thing in things
+        ], t
+
+    def __getitem__(self, pair_idx):
+        # read intrinsics of original size
+        idx1, idx2 = self.pairs[pair_idx]
+        K1 = torch.tensor(self.intrinsics[idx1].copy(), dtype=torch.float).reshape(3, 3)
+        K2 = torch.tensor(self.intrinsics[idx2].copy(), dtype=torch.float).reshape(3, 3)
+
+        # read and compute relative poses
+        T1 = self.poses[idx1]
+        T2 = self.poses[idx2]
+        T_1to2 = torch.tensor(np.matmul(T2, np.linalg.inv(T1)), dtype=torch.float)[
+            :4, :4
+        ]  # (4, 4)
+
+        # Load positive pair data
+        im_A, im_B = self.image_paths[idx1], self.image_paths[idx2]
+        depth1, depth2 = self.depth_paths[idx1], self.depth_paths[idx2]
+        im_A_ref = os.path.join(self.data_root, im_A)
+        im_B_ref = os.path.join(self.data_root, im_B)
+        depth_A_ref = os.path.join(self.data_root, depth1)
+        depth_B_ref = os.path.join(self.data_root, depth2)
+        im_A = self.load_im(im_A_ref)
+        im_B = self.load_im(im_B_ref)
+        K1 = self.scale_intrinsic(K1, im_A.width, im_A.height)
+        K2 = self.scale_intrinsic(K2, im_B.width, im_B.height)
+
+        if self.use_randaug:
+            im_A, im_B = self.rand_augment(im_A, im_B)
+
+        depth_A = self.load_depth(depth_A_ref)
+        depth_B = self.load_depth(depth_B_ref)
+        # Process images
+        im_A, im_B = self.im_transform_ops((im_A, im_B))
+        depth_A, depth_B = self.depth_transform_ops(
+            (depth_A[None, None], depth_B[None, None])
+        )
+        
+        [im_A, im_B, depth_A, depth_B], t = self.rand_shake(im_A, im_B, depth_A, depth_B)
+        K1[:2, 2] += t
+        K2[:2, 2] += t
+        
+        im_A, im_B = im_A[None], im_B[None]
+        if self.random_eraser is not None:
+            im_A, depth_A = self.random_eraser(im_A, depth_A)
+            im_B, depth_B = self.random_eraser(im_B, depth_B)
+                
+        if self.use_horizontal_flip_aug:
+            if np.random.rand() > 0.5:
+                im_A, im_B, depth_A, depth_B, K1, K2 = self.horizontal_flip(im_A, im_B, depth_A, depth_B, K1, K2)
+        if self.use_single_horizontal_flip_aug:
+            if np.random.rand() > 0.5:
+                im_B, depth_B, K2 = self.single_horizontal_flip(im_B, depth_B, K2)
+        
+        if roma.DEBUG_MODE:
+            tensor_to_pil(im_A[0], unnormalize=True).save(
+                            f"vis/im_A.jpg")
+            tensor_to_pil(im_B[0], unnormalize=True).save(
+                            f"vis/im_B.jpg")
+            
+        data_dict = {
+            "im_A": im_A[0],
+            "im_A_identifier": self.image_paths[idx1].split("/")[-1].split(".jpg")[0],
+            "im_B": im_B[0],
+            "im_B_identifier": self.image_paths[idx2].split("/")[-1].split(".jpg")[0],
+            "im_A_depth": depth_A[0, 0],
+            "im_B_depth": depth_B[0, 0],
+            "K1": K1,
+            "K2": K2,
+            "T_1to2": T_1to2,
+            "im_A_path": im_A_ref,
+            "im_B_path": im_B_ref,
+            
+        }
+        return data_dict
+
+
+class MegadepthBuilder:
+    def __init__(self, data_root="data/megadepth", loftr_ignore=True, imc21_ignore = True) -> None:
+        self.data_root = data_root
+        self.scene_info_root = os.path.join(data_root, "prep_scene_info")
+        self.all_scenes = os.listdir(self.scene_info_root)
+        self.test_scenes = ["0017.npy", "0004.npy", "0048.npy", "0013.npy"]
+        # LoFTR did the D2-net preprocessing differently than we did and got more ignore scenes, can optionially ignore those
+        self.loftr_ignore_scenes = set(['0121.npy', '0133.npy', '0168.npy', '0178.npy', '0229.npy', '0349.npy', '0412.npy', '0430.npy', '0443.npy', '1001.npy', '5014.npy', '5015.npy', '5016.npy'])
+        self.imc21_scenes = set(['0008.npy', '0019.npy', '0021.npy', '0024.npy', '0025.npy', '0032.npy', '0063.npy', '1589.npy'])
+        self.test_scenes_loftr = ["0015.npy", "0022.npy"]
+        self.loftr_ignore = loftr_ignore
+        self.imc21_ignore = imc21_ignore
+
+    def build_scenes(self, split="train", min_overlap=0.0, scene_names = None, **kwargs):
+        if split == "train":
+            scene_names = set(self.all_scenes) - set(self.test_scenes)
+        elif split == "train_loftr":
+            scene_names = set(self.all_scenes) - set(self.test_scenes_loftr)
+        elif split == "test":
+            scene_names = self.test_scenes
+        elif split == "test_loftr":
+            scene_names = self.test_scenes_loftr
+        elif split == "custom":
+            scene_names = scene_names
+        else:
+            raise ValueError(f"Split {split} not available")
+        scenes = []
+        for scene_name in scene_names:
+            if self.loftr_ignore and scene_name in self.loftr_ignore_scenes:
+                continue
+            if self.imc21_ignore and scene_name in self.imc21_scenes:
+                continue
+            scene_info = np.load(
+                os.path.join(self.scene_info_root, scene_name), allow_pickle=True
+            ).item()
+            scenes.append(
+                MegadepthScene(
+                    self.data_root, scene_info, min_overlap=min_overlap,scene_name = scene_name, **kwargs
+                )
+            )
+        return scenes
+
+    def weight_scenes(self, concat_dataset, alpha=0.5):
+        ns = []
+        for d in concat_dataset.datasets:
+            ns.append(len(d))
+        ws = torch.cat([torch.ones(n) / n**alpha for n in ns])
+        return ws

+ 160 - 0
roma/datasets/scannet.py

@@ -0,0 +1,160 @@
+import os
+import random
+from PIL import Image
+import cv2
+import h5py
+import numpy as np
+import torch
+from torch.utils.data import (
+    Dataset,
+    DataLoader,
+    ConcatDataset)
+
+import torchvision.transforms.functional as tvf
+import kornia.augmentation as K
+import os.path as osp
+import matplotlib.pyplot as plt
+import roma
+from roma.utils import get_depth_tuple_transform_ops, get_tuple_transform_ops
+from roma.utils.transforms import GeometricSequential
+from tqdm import tqdm
+
+class ScanNetScene:
+    def __init__(self, data_root, scene_info, ht = 384, wt = 512, min_overlap=0., shake_t = 0, rot_prob=0.,use_horizontal_flip_aug = False,
+) -> None:
+        self.scene_root = osp.join(data_root,"scans","scans_train")
+        self.data_names = scene_info['name']
+        self.overlaps = scene_info['score']
+        # Only sample 10s
+        valid = (self.data_names[:,-2:] % 10).sum(axis=-1) == 0
+        self.overlaps = self.overlaps[valid]
+        self.data_names = self.data_names[valid]
+        if len(self.data_names) > 10000:
+            pairinds = np.random.choice(np.arange(0,len(self.data_names)),10000,replace=False)
+            self.data_names = self.data_names[pairinds]
+            self.overlaps = self.overlaps[pairinds]
+        self.im_transform_ops = get_tuple_transform_ops(resize=(ht, wt), normalize=True)
+        self.depth_transform_ops = get_depth_tuple_transform_ops(resize=(ht, wt), normalize=False)
+        self.wt, self.ht = wt, ht
+        self.shake_t = shake_t
+        self.H_generator = GeometricSequential(K.RandomAffine(degrees=90, p=rot_prob))
+        self.use_horizontal_flip_aug = use_horizontal_flip_aug
+
+    def load_im(self, im_B, crop=None):
+        im = Image.open(im_B)
+        return im
+    
+    def load_depth(self, depth_ref, crop=None):
+        depth = cv2.imread(str(depth_ref), cv2.IMREAD_UNCHANGED)
+        depth = depth / 1000
+        depth = torch.from_numpy(depth).float()  # (h, w)
+        return depth
+
+    def __len__(self):
+        return len(self.data_names)
+    
+    def scale_intrinsic(self, K, wi, hi):
+        sx, sy = self.wt / wi, self.ht /  hi
+        sK = torch.tensor([[sx, 0, 0],
+                        [0, sy, 0],
+                        [0, 0, 1]])
+        return sK@K
+
+    def horizontal_flip(self, im_A, im_B, depth_A, depth_B,  K_A, K_B):
+        im_A = im_A.flip(-1)
+        im_B = im_B.flip(-1)
+        depth_A, depth_B = depth_A.flip(-1), depth_B.flip(-1) 
+        flip_mat = torch.tensor([[-1, 0, self.wt],[0,1,0],[0,0,1.]]).to(K_A.device)
+        K_A = flip_mat@K_A  
+        K_B = flip_mat@K_B  
+        
+        return im_A, im_B, depth_A, depth_B, K_A, K_B
+    def read_scannet_pose(self,path):
+        """ Read ScanNet's Camera2World pose and transform it to World2Camera.
+        
+        Returns:
+            pose_w2c (np.ndarray): (4, 4)
+        """
+        cam2world = np.loadtxt(path, delimiter=' ')
+        world2cam = np.linalg.inv(cam2world)
+        return world2cam
+
+
+    def read_scannet_intrinsic(self,path):
+        """ Read ScanNet's intrinsic matrix and return the 3x3 matrix.
+        """
+        intrinsic = np.loadtxt(path, delimiter=' ')
+        return torch.tensor(intrinsic[:-1, :-1], dtype = torch.float)
+
+    def __getitem__(self, pair_idx):
+        # read intrinsics of original size
+        data_name = self.data_names[pair_idx]
+        scene_name, scene_sub_name, stem_name_1, stem_name_2 = data_name
+        scene_name = f'scene{scene_name:04d}_{scene_sub_name:02d}'
+        
+        # read the intrinsic of depthmap
+        K1 = K2 =  self.read_scannet_intrinsic(osp.join(self.scene_root,
+                       scene_name,
+                       'intrinsic', 'intrinsic_color.txt'))#the depth K is not the same, but doesnt really matter
+        # read and compute relative poses
+        T1 =  self.read_scannet_pose(osp.join(self.scene_root,
+                       scene_name,
+                       'pose', f'{stem_name_1}.txt'))
+        T2 =  self.read_scannet_pose(osp.join(self.scene_root,
+                       scene_name,
+                       'pose', f'{stem_name_2}.txt'))
+        T_1to2 = torch.tensor(np.matmul(T2, np.linalg.inv(T1)), dtype=torch.float)[:4, :4]  # (4, 4)
+
+        # Load positive pair data
+        im_A_ref = os.path.join(self.scene_root, scene_name, 'color', f'{stem_name_1}.jpg')
+        im_B_ref = os.path.join(self.scene_root, scene_name, 'color', f'{stem_name_2}.jpg')
+        depth_A_ref = os.path.join(self.scene_root, scene_name, 'depth', f'{stem_name_1}.png')
+        depth_B_ref = os.path.join(self.scene_root, scene_name, 'depth', f'{stem_name_2}.png')
+
+        im_A = self.load_im(im_A_ref)
+        im_B = self.load_im(im_B_ref)
+        depth_A = self.load_depth(depth_A_ref)
+        depth_B = self.load_depth(depth_B_ref)
+
+        # Recompute camera intrinsic matrix due to the resize
+        K1 = self.scale_intrinsic(K1, im_A.width, im_A.height)
+        K2 = self.scale_intrinsic(K2, im_B.width, im_B.height)
+        # Process images
+        im_A, im_B = self.im_transform_ops((im_A, im_B))
+        depth_A, depth_B = self.depth_transform_ops((depth_A[None,None], depth_B[None,None]))
+        if self.use_horizontal_flip_aug:
+            if np.random.rand() > 0.5:
+                im_A, im_B, depth_A, depth_B, K1, K2 = self.horizontal_flip(im_A, im_B, depth_A, depth_B, K1, K2)
+
+        data_dict = {'im_A': im_A,
+                    'im_B': im_B,
+                    'im_A_depth': depth_A[0,0],
+                    'im_B_depth': depth_B[0,0],
+                    'K1': K1,
+                    'K2': K2,
+                    'T_1to2':T_1to2,
+                    }
+        return data_dict
+
+
+class ScanNetBuilder:
+    def __init__(self, data_root = 'data/scannet') -> None:
+        self.data_root = data_root
+        self.scene_info_root = os.path.join(data_root,'scannet_indices')
+        self.all_scenes = os.listdir(self.scene_info_root)
+        
+    def build_scenes(self, split = 'train', min_overlap=0., **kwargs):
+        # Note: split doesn't matter here as we always use same scannet_train scenes
+        scene_names = self.all_scenes
+        scenes = []
+        for scene_name in tqdm(scene_names, disable = roma.RANK > 0):
+            scene_info = np.load(os.path.join(self.scene_info_root,scene_name), allow_pickle=True)
+            scenes.append(ScanNetScene(self.data_root, scene_info, min_overlap=min_overlap, **kwargs))
+        return scenes
+    
+    def weight_scenes(self, concat_dataset, alpha=.5):
+        ns = []
+        for d in concat_dataset.datasets:
+            ns.append(len(d))
+        ws = torch.cat([torch.ones(n)/n**alpha for n in ns])
+        return ws

+ 1 - 0
roma/losses/__init__.py

@@ -0,0 +1 @@
+from .robust_loss import RobustLosses

+ 157 - 0
roma/losses/robust_loss.py

@@ -0,0 +1,157 @@
+from einops.einops import rearrange
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from roma.utils.utils import get_gt_warp
+import wandb
+import roma
+import math
+
+class RobustLosses(nn.Module):
+    def __init__(
+        self,
+        robust=False,
+        center_coords=False,
+        scale_normalize=False,
+        ce_weight=0.01,
+        local_loss=True,
+        local_dist=4.0,
+        local_largest_scale=8,
+        smooth_mask = False,
+        depth_interpolation_mode = "bilinear",
+        mask_depth_loss = False,
+        relative_depth_error_threshold = 0.05,
+        alpha = 1.,
+        c = 1e-3,
+    ):
+        super().__init__()
+        self.robust = robust  # measured in pixels
+        self.center_coords = center_coords
+        self.scale_normalize = scale_normalize
+        self.ce_weight = ce_weight
+        self.local_loss = local_loss
+        self.local_dist = local_dist
+        self.local_largest_scale = local_largest_scale
+        self.smooth_mask = smooth_mask
+        self.depth_interpolation_mode = depth_interpolation_mode
+        self.mask_depth_loss = mask_depth_loss
+        self.relative_depth_error_threshold = relative_depth_error_threshold
+        self.avg_overlap = dict()
+        self.alpha = alpha
+        self.c = c
+
+    def gm_cls_loss(self, x2, prob, scale_gm_cls, gm_certainty, scale):
+        with torch.no_grad():
+            B, C, H, W = scale_gm_cls.shape
+            device = x2.device
+            cls_res = round(math.sqrt(C))
+            G = torch.meshgrid(*[torch.linspace(-1+1/cls_res, 1 - 1/cls_res, steps = cls_res,device = device) for _ in range(2)])
+            G = torch.stack((G[1], G[0]), dim = -1).reshape(C,2)
+            GT = (G[None,:,None,None,:]-x2[:,None]).norm(dim=-1).min(dim=1).indices
+        cls_loss = F.cross_entropy(scale_gm_cls, GT, reduction  = 'none')[prob > 0.99]
+        if not torch.any(cls_loss):
+            cls_loss = (certainty_loss * 0.0)  # Prevent issues where prob is 0 everywhere
+
+        certainty_loss = F.binary_cross_entropy_with_logits(gm_certainty[:,0], prob)
+        losses = {
+            f"gm_certainty_loss_{scale}": certainty_loss.mean(),
+            f"gm_cls_loss_{scale}": cls_loss.mean(),
+        }
+        wandb.log(losses, step = roma.GLOBAL_STEP)
+        return losses
+
+    def delta_cls_loss(self, x2, prob, flow_pre_delta, delta_cls, certainty, scale, offset_scale):
+        with torch.no_grad():
+            B, C, H, W = delta_cls.shape
+            device = x2.device
+            cls_res = round(math.sqrt(C))
+            G = torch.meshgrid(*[torch.linspace(-1+1/cls_res, 1 - 1/cls_res, steps = cls_res,device = device) for _ in range(2)])
+            G = torch.stack((G[1], G[0]), dim = -1).reshape(C,2) * offset_scale
+            GT = (G[None,:,None,None,:] + flow_pre_delta[:,None] - x2[:,None]).norm(dim=-1).min(dim=1).indices
+        cls_loss = F.cross_entropy(delta_cls, GT, reduction  = 'none')[prob > 0.99]
+        if not torch.any(cls_loss):
+            cls_loss = (certainty_loss * 0.0)  # Prevent issues where prob is 0 everywhere
+        certainty_loss = F.binary_cross_entropy_with_logits(certainty[:,0], prob)
+        losses = {
+            f"delta_certainty_loss_{scale}": certainty_loss.mean(),
+            f"delta_cls_loss_{scale}": cls_loss.mean(),
+        }
+        wandb.log(losses, step = roma.GLOBAL_STEP)
+        return losses
+
+    def regression_loss(self, x2, prob, flow, certainty, scale, eps=1e-8, mode = "delta"):
+        epe = (flow.permute(0,2,3,1) - x2).norm(dim=-1)
+        if scale == 1:
+            pck_05 = (epe[prob > 0.99] < 0.5 * (2/512)).float().mean()
+            wandb.log({"train_pck_05": pck_05}, step = roma.GLOBAL_STEP)
+
+        ce_loss = F.binary_cross_entropy_with_logits(certainty[:, 0], prob)
+        a = self.alpha
+        cs = self.c * scale
+        x = epe[prob > 0.99]
+        reg_loss = cs**a * ((x/(cs))**2 + 1**2)**(a/2)
+        if not torch.any(reg_loss):
+            reg_loss = (ce_loss * 0.0)  # Prevent issues where prob is 0 everywhere
+        losses = {
+            f"{mode}_certainty_loss_{scale}": ce_loss.mean(),
+            f"{mode}_regression_loss_{scale}": reg_loss.mean(),
+        }
+        wandb.log(losses, step = roma.GLOBAL_STEP)
+        return losses
+
+    def forward(self, corresps, batch):
+        scales = list(corresps.keys())
+        tot_loss = 0.0
+        # scale_weights due to differences in scale for regression gradients and classification gradients
+        scale_weights = {1:1, 2:1, 4:1, 8:1, 16:1}
+        for scale in scales:
+            scale_corresps = corresps[scale]
+            scale_certainty, flow_pre_delta, delta_cls, offset_scale, scale_gm_cls, scale_gm_certainty, flow, scale_gm_flow = (
+                scale_corresps["certainty"],
+                scale_corresps["flow_pre_delta"],
+                scale_corresps.get("delta_cls"),
+                scale_corresps.get("offset_scale"),
+                scale_corresps.get("gm_cls"),
+                scale_corresps.get("gm_certainty"),
+                scale_corresps["flow"],
+                scale_corresps.get("gm_flow"),
+
+            )
+            flow_pre_delta = rearrange(flow_pre_delta, "b d h w -> b h w d")
+            b, h, w, d = flow_pre_delta.shape
+            gt_warp, gt_prob = get_gt_warp(                
+            batch["im_A_depth"],
+            batch["im_B_depth"],
+            batch["T_1to2"],
+            batch["K1"],
+            batch["K2"],
+            H=h,
+            W=w,
+        )
+            x2 = gt_warp.float()
+            prob = gt_prob
+            
+            if self.local_largest_scale >= scale:
+                prob = prob * (
+                        F.interpolate(prev_epe[:, None], size=(h, w), mode="nearest-exact")[:, 0]
+                        < (2 / 512) * (self.local_dist[scale] * scale))
+            
+            if scale_gm_cls is not None:
+                gm_cls_losses = self.gm_cls_loss(x2, prob, scale_gm_cls, scale_gm_certainty, scale)
+                gm_loss = self.ce_weight * gm_cls_losses[f"gm_certainty_loss_{scale}"] + gm_cls_losses[f"gm_cls_loss_{scale}"]
+                tot_loss = tot_loss + scale_weights[scale] * gm_loss
+            elif scale_gm_flow is not None:
+                gm_flow_losses = self.regression_loss(x2, prob, scale_gm_flow, scale_gm_certainty, scale, mode = "gm")
+                gm_loss = self.ce_weight * gm_flow_losses[f"gm_certainty_loss_{scale}"] + gm_flow_losses[f"gm_regression_loss_{scale}"]
+                tot_loss = tot_loss + scale_weights[scale] * gm_loss
+            
+            if delta_cls is not None:
+                delta_cls_losses = self.delta_cls_loss(x2, prob, flow_pre_delta, delta_cls, scale_certainty, scale, offset_scale)
+                delta_cls_loss = self.ce_weight * delta_cls_losses[f"delta_certainty_loss_{scale}"] + delta_cls_losses[f"delta_cls_loss_{scale}"]
+                tot_loss = tot_loss + scale_weights[scale] * delta_cls_loss
+            else:
+                delta_regression_losses = self.regression_loss(x2, prob, flow, scale_certainty, scale)
+                reg_loss = self.ce_weight * delta_regression_losses[f"delta_certainty_loss_{scale}"] + delta_regression_losses[f"delta_regression_loss_{scale}"]
+                tot_loss = tot_loss + scale_weights[scale] * reg_loss
+            prev_epe = (flow.permute(0,2,3,1) - x2).norm(dim=-1).detach()
+        return tot_loss

+ 1 - 0
roma/models/__init__.py

@@ -0,0 +1 @@
+from .model_zoo import roma_outdoor, roma_indoor

+ 118 - 0
roma/models/encoders.py

@@ -0,0 +1,118 @@
+from typing import Optional, Union
+import torch
+from torch import device
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision.models as tvm
+import gc
+
+
+class ResNet50(nn.Module):
+    def __init__(self, pretrained=False, high_res = False, weights = None, dilation = None, freeze_bn = True, anti_aliased = False, early_exit = False, amp = False) -> None:
+        super().__init__()
+        if dilation is None:
+            dilation = [False,False,False]
+        if anti_aliased:
+            pass
+        else:
+            if weights is not None:
+                self.net = tvm.resnet50(weights = weights,replace_stride_with_dilation=dilation)
+            else:
+                self.net = tvm.resnet50(pretrained=pretrained,replace_stride_with_dilation=dilation)
+            
+        self.high_res = high_res
+        self.freeze_bn = freeze_bn
+        self.early_exit = early_exit
+        self.amp = amp
+        self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+
+    def forward(self, x, **kwargs):
+        with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype):
+            net = self.net
+            feats = {1:x}
+            x = net.conv1(x)
+            x = net.bn1(x)
+            x = net.relu(x)
+            feats[2] = x 
+            x = net.maxpool(x)
+            x = net.layer1(x)
+            feats[4] = x 
+            x = net.layer2(x)
+            feats[8] = x
+            if self.early_exit:
+                return feats
+            x = net.layer3(x)
+            feats[16] = x
+            x = net.layer4(x)
+            feats[32] = x
+            return feats
+
+    def train(self, mode=True):
+        super().train(mode)
+        if self.freeze_bn:
+            for m in self.modules():
+                if isinstance(m, nn.BatchNorm2d):
+                    m.eval()
+                pass
+
+class VGG19(nn.Module):
+    def __init__(self, pretrained=False, amp = False) -> None:
+        super().__init__()
+        self.layers = nn.ModuleList(tvm.vgg19_bn(pretrained=pretrained).features[:40])
+        self.amp = amp
+        self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+
+    def forward(self, x, **kwargs):
+        with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype):
+            feats = {}
+            scale = 1
+            for layer in self.layers:
+                if isinstance(layer, nn.MaxPool2d):
+                    feats[scale] = x
+                    scale = scale*2
+                x = layer(x)
+            return feats
+
+class CNNandDinov2(nn.Module):
+    def __init__(self, cnn_kwargs = None, amp = False, use_vgg = False, dinov2_weights = None):
+        super().__init__()
+        if dinov2_weights is None:
+            dinov2_weights = torch.hub.load_state_dict_from_url("https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth", map_location="cpu")
+        from .transformer import vit_large
+        vit_kwargs = dict(img_size= 518,
+            patch_size= 14,
+            init_values = 1.0,
+            ffn_layer = "mlp",
+            block_chunks = 0,
+        )
+
+        dinov2_vitl14 = vit_large(**vit_kwargs).eval()
+        dinov2_vitl14.load_state_dict(dinov2_weights)
+        cnn_kwargs = cnn_kwargs if cnn_kwargs is not None else {}
+        if not use_vgg:
+            self.cnn = ResNet50(**cnn_kwargs)
+        else:
+            self.cnn = VGG19(**cnn_kwargs)
+        self.amp = amp
+        self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+        if self.amp:
+            dinov2_vitl14 = dinov2_vitl14.to(self.amp_dtype)
+        self.dinov2_vitl14 = [dinov2_vitl14] # ugly hack to not show parameters to DDP
+    
+    
+    def train(self, mode: bool = True):
+        return self.cnn.train(mode)
+    
+    def forward(self, x, upsample = False):
+        B,C,H,W = x.shape
+        feature_pyramid = self.cnn(x)
+        
+        if not upsample:
+            with torch.no_grad():
+                if self.dinov2_vitl14[0].device != x.device:
+                    self.dinov2_vitl14[0] = self.dinov2_vitl14[0].to(x.device).to(self.amp_dtype)
+                dinov2_features_16 = self.dinov2_vitl14[0].forward_features(x.to(self.amp_dtype))
+                features_16 = dinov2_features_16['x_norm_patchtokens'].permute(0,2,1).reshape(B,1024,H//14, W//14)
+                del dinov2_features_16
+                feature_pyramid[16] = features_16
+        return feature_pyramid

+ 649 - 0
roma/models/matcher.py

@@ -0,0 +1,649 @@
+import os
+import math
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+import warnings
+from warnings import warn
+
+import roma
+from roma.utils import get_tuple_transform_ops
+from roma.utils.local_correlation import local_correlation
+from roma.utils.utils import cls_to_flow_refine
+from roma.utils.kde import kde
+
+class ConvRefiner(nn.Module):
+    def __init__(
+        self,
+        in_dim=6,
+        hidden_dim=16,
+        out_dim=2,
+        dw=False,
+        kernel_size=5,
+        hidden_blocks=3,
+        displacement_emb = None,
+        displacement_emb_dim = None,
+        local_corr_radius = None,
+        corr_in_other = None,
+        no_im_B_fm = False,
+        amp = False,
+        concat_logits = False,
+        use_bias_block_1 = True,
+        use_cosine_corr = False,
+        disable_local_corr_grad = False,
+        is_classifier = False,
+        sample_mode = "bilinear",
+        norm_type = nn.BatchNorm2d,
+        bn_momentum = 0.1,
+    ):
+        super().__init__()
+        self.bn_momentum = bn_momentum
+        self.block1 = self.create_block(
+            in_dim, hidden_dim, dw=dw, kernel_size=kernel_size, bias = use_bias_block_1,
+        )
+        self.hidden_blocks = nn.Sequential(
+            *[
+                self.create_block(
+                    hidden_dim,
+                    hidden_dim,
+                    dw=dw,
+                    kernel_size=kernel_size,
+                    norm_type=norm_type,
+                )
+                for hb in range(hidden_blocks)
+            ]
+        )
+        self.hidden_blocks = self.hidden_blocks
+        self.out_conv = nn.Conv2d(hidden_dim, out_dim, 1, 1, 0)
+        if displacement_emb:
+            self.has_displacement_emb = True
+            self.disp_emb = nn.Conv2d(2,displacement_emb_dim,1,1,0)
+        else:
+            self.has_displacement_emb = False
+        self.local_corr_radius = local_corr_radius
+        self.corr_in_other = corr_in_other
+        self.no_im_B_fm = no_im_B_fm
+        self.amp = amp
+        self.concat_logits = concat_logits
+        self.use_cosine_corr = use_cosine_corr
+        self.disable_local_corr_grad = disable_local_corr_grad
+        self.is_classifier = is_classifier
+        self.sample_mode = sample_mode
+        self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+        
+    def create_block(
+        self,
+        in_dim,
+        out_dim,
+        dw=False,
+        kernel_size=5,
+        bias = True,
+        norm_type = nn.BatchNorm2d,
+    ):
+        num_groups = 1 if not dw else in_dim
+        if dw:
+            assert (
+                out_dim % in_dim == 0
+            ), "outdim must be divisible by indim for depthwise"
+        conv1 = nn.Conv2d(
+            in_dim,
+            out_dim,
+            kernel_size=kernel_size,
+            stride=1,
+            padding=kernel_size // 2,
+            groups=num_groups,
+            bias=bias,
+        )
+        norm = norm_type(out_dim, momentum = self.bn_momentum) if norm_type is nn.BatchNorm2d else norm_type(num_channels = out_dim)
+        relu = nn.ReLU(inplace=True)
+        conv2 = nn.Conv2d(out_dim, out_dim, 1, 1, 0)
+        return nn.Sequential(conv1, norm, relu, conv2)
+        
+    def forward(self, x, y, flow, scale_factor = 1, logits = None):
+        b,c,hs,ws = x.shape
+        with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype):
+            with torch.no_grad():
+                x_hat = F.grid_sample(y, flow.permute(0, 2, 3, 1), align_corners=False, mode = self.sample_mode)
+            if self.has_displacement_emb:
+                im_A_coords = torch.meshgrid(
+                (
+                    torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device="cuda"),
+                    torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device="cuda"),
+                )
+                )
+                im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
+                im_A_coords = im_A_coords[None].expand(b, 2, hs, ws)
+                in_displacement = flow-im_A_coords
+                emb_in_displacement = self.disp_emb(40/32 * scale_factor * in_displacement)
+                if self.local_corr_radius:
+                    if self.corr_in_other:
+                        # Corr in other means take a kxk grid around the predicted coordinate in other image
+                        local_corr = local_correlation(x,y,local_radius=self.local_corr_radius,flow = flow, 
+                                                       sample_mode = self.sample_mode)
+                    else:
+                        raise NotImplementedError("Local corr in own frame should not be used.")
+                    if self.no_im_B_fm:
+                        x_hat = torch.zeros_like(x)
+                    d = torch.cat((x, x_hat, emb_in_displacement, local_corr), dim=1)
+                else:    
+                    d = torch.cat((x, x_hat, emb_in_displacement), dim=1)
+            else:
+                if self.no_im_B_fm:
+                    x_hat = torch.zeros_like(x)
+                d = torch.cat((x, x_hat), dim=1)
+            if self.concat_logits:
+                d = torch.cat((d, logits), dim=1)
+            d = self.block1(d)
+            d = self.hidden_blocks(d)
+        d = self.out_conv(d.float())
+        displacement, certainty = d[:, :-1], d[:, -1:]
+        return displacement, certainty
+
+class CosKernel(nn.Module):  # similar to softmax kernel
+    def __init__(self, T, learn_temperature=False):
+        super().__init__()
+        self.learn_temperature = learn_temperature
+        if self.learn_temperature:
+            self.T = nn.Parameter(torch.tensor(T))
+        else:
+            self.T = T
+
+    def __call__(self, x, y, eps=1e-6):
+        c = torch.einsum("bnd,bmd->bnm", x, y) / (
+            x.norm(dim=-1)[..., None] * y.norm(dim=-1)[:, None] + eps
+        )
+        if self.learn_temperature:
+            T = self.T.abs() + 0.01
+        else:
+            T = torch.tensor(self.T, device=c.device)
+        K = ((c - 1.0) / T).exp()
+        return K
+
+class GP(nn.Module):
+    def __init__(
+        self,
+        kernel,
+        T=1,
+        learn_temperature=False,
+        only_attention=False,
+        gp_dim=64,
+        basis="fourier",
+        covar_size=5,
+        only_nearest_neighbour=False,
+        sigma_noise=0.1,
+        no_cov=False,
+        predict_features = False,
+    ):
+        super().__init__()
+        self.K = kernel(T=T, learn_temperature=learn_temperature)
+        self.sigma_noise = sigma_noise
+        self.covar_size = covar_size
+        self.pos_conv = torch.nn.Conv2d(2, gp_dim, 1, 1)
+        self.only_attention = only_attention
+        self.only_nearest_neighbour = only_nearest_neighbour
+        self.basis = basis
+        self.no_cov = no_cov
+        self.dim = gp_dim
+        self.predict_features = predict_features
+
+    def get_local_cov(self, cov):
+        K = self.covar_size
+        b, h, w, h, w = cov.shape
+        hw = h * w
+        cov = F.pad(cov, 4 * (K // 2,))  # pad v_q
+        delta = torch.stack(
+            torch.meshgrid(
+                torch.arange(-(K // 2), K // 2 + 1), torch.arange(-(K // 2), K // 2 + 1)
+            ),
+            dim=-1,
+        )
+        positions = torch.stack(
+            torch.meshgrid(
+                torch.arange(K // 2, h + K // 2), torch.arange(K // 2, w + K // 2)
+            ),
+            dim=-1,
+        )
+        neighbours = positions[:, :, None, None, :] + delta[None, :, :]
+        points = torch.arange(hw)[:, None].expand(hw, K**2)
+        local_cov = cov.reshape(b, hw, h + K - 1, w + K - 1)[
+            :,
+            points.flatten(),
+            neighbours[..., 0].flatten(),
+            neighbours[..., 1].flatten(),
+        ].reshape(b, h, w, K**2)
+        return local_cov
+
+    def reshape(self, x):
+        return rearrange(x, "b d h w -> b (h w) d")
+
+    def project_to_basis(self, x):
+        if self.basis == "fourier":
+            return torch.cos(8 * math.pi * self.pos_conv(x))
+        elif self.basis == "linear":
+            return self.pos_conv(x)
+        else:
+            raise ValueError(
+                "No other bases other than fourier and linear currently im_Bed in public release"
+            )
+
+    def get_pos_enc(self, y):
+        b, c, h, w = y.shape
+        coarse_coords = torch.meshgrid(
+            (
+                torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=y.device),
+                torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=y.device),
+            )
+        )
+
+        coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[
+            None
+        ].expand(b, h, w, 2)
+        coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w")
+        coarse_embedded_coords = self.project_to_basis(coarse_coords)
+        return coarse_embedded_coords
+
+    def forward(self, x, y, **kwargs):
+        b, c, h1, w1 = x.shape
+        b, c, h2, w2 = y.shape
+        f = self.get_pos_enc(y)
+        b, d, h2, w2 = f.shape
+        x, y, f = self.reshape(x.float()), self.reshape(y.float()), self.reshape(f)
+        K_xx = self.K(x, x)
+        K_yy = self.K(y, y)
+        K_xy = self.K(x, y)
+        K_yx = K_xy.permute(0, 2, 1)
+        sigma_noise = self.sigma_noise * torch.eye(h2 * w2, device=x.device)[None, :, :]
+        with warnings.catch_warnings():
+            K_yy_inv = torch.linalg.inv(K_yy + sigma_noise)
+
+        mu_x = K_xy.matmul(K_yy_inv.matmul(f))
+        mu_x = rearrange(mu_x, "b (h w) d -> b d h w", h=h1, w=w1)
+        if not self.no_cov:
+            cov_x = K_xx - K_xy.matmul(K_yy_inv.matmul(K_yx))
+            cov_x = rearrange(cov_x, "b (h w) (r c) -> b h w r c", h=h1, w=w1, r=h1, c=w1)
+            local_cov_x = self.get_local_cov(cov_x)
+            local_cov_x = rearrange(local_cov_x, "b h w K -> b K h w")
+            gp_feats = torch.cat((mu_x, local_cov_x), dim=1)
+        else:
+            gp_feats = mu_x
+        return gp_feats
+
+class Decoder(nn.Module):
+    def __init__(
+        self, embedding_decoder, gps, proj, conv_refiner, detach=False, scales="all", pos_embeddings = None,
+        num_refinement_steps_per_scale = 1, warp_noise_std = 0.0, displacement_dropout_p = 0.0, gm_warp_dropout_p = 0.0,
+        flow_upsample_mode = "bilinear"
+    ):
+        super().__init__()
+        self.embedding_decoder = embedding_decoder
+        self.num_refinement_steps_per_scale = num_refinement_steps_per_scale
+        self.gps = gps
+        self.proj = proj
+        self.conv_refiner = conv_refiner
+        self.detach = detach
+        if pos_embeddings is None:
+            self.pos_embeddings = {}
+        else:
+            self.pos_embeddings = pos_embeddings
+        if scales == "all":
+            self.scales = ["32", "16", "8", "4", "2", "1"]
+        else:
+            self.scales = scales
+        self.warp_noise_std = warp_noise_std
+        self.refine_init = 4
+        self.displacement_dropout_p = displacement_dropout_p
+        self.gm_warp_dropout_p = gm_warp_dropout_p
+        self.flow_upsample_mode = flow_upsample_mode
+        self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+        
+    def get_placeholder_flow(self, b, h, w, device):
+        coarse_coords = torch.meshgrid(
+            (
+                torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device),
+                torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=device),
+            )
+        )
+        coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[
+            None
+        ].expand(b, h, w, 2)
+        coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w")
+        return coarse_coords
+    
+    def get_positional_embedding(self, b, h ,w, device):
+        coarse_coords = torch.meshgrid(
+            (
+                torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device),
+                torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=device),
+            )
+        )
+
+        coarse_coords = torch.stack((coarse_coords[1], coarse_coords[0]), dim=-1)[
+            None
+        ].expand(b, h, w, 2)
+        coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w")
+        coarse_embedded_coords = self.pos_embedding(coarse_coords)
+        return coarse_embedded_coords
+
+    def forward(self, f1, f2, gt_warp = None, gt_prob = None, upsample = False, flow = None, certainty = None, scale_factor = 1):
+        coarse_scales = self.embedding_decoder.scales()
+        all_scales = self.scales if not upsample else ["8", "4", "2", "1"] 
+        sizes = {scale: f1[scale].shape[-2:] for scale in f1}
+        h, w = sizes[1]
+        b = f1[1].shape[0]
+        device = f1[1].device
+        coarsest_scale = int(all_scales[0])
+        old_stuff = torch.zeros(
+            b, self.embedding_decoder.hidden_dim, *sizes[coarsest_scale], device=f1[coarsest_scale].device
+        )
+        corresps = {}
+        if not upsample:
+            flow = self.get_placeholder_flow(b, *sizes[coarsest_scale], device)
+            certainty = 0.0
+        else:
+            flow = F.interpolate(
+                    flow,
+                    size=sizes[coarsest_scale],
+                    align_corners=False,
+                    mode="bilinear",
+                )
+            certainty = F.interpolate(
+                    certainty,
+                    size=sizes[coarsest_scale],
+                    align_corners=False,
+                    mode="bilinear",
+                )
+        displacement = 0.0
+        for new_scale in all_scales:
+            ins = int(new_scale)
+            corresps[ins] = {}
+            f1_s, f2_s = f1[ins], f2[ins]
+            if new_scale in self.proj:
+                with torch.autocast("cuda", self.amp_dtype):
+                    f1_s, f2_s = self.proj[new_scale](f1_s), self.proj[new_scale](f2_s)
+
+            if ins in coarse_scales:
+                old_stuff = F.interpolate(
+                    old_stuff, size=sizes[ins], mode="bilinear", align_corners=False
+                )
+                gp_posterior = self.gps[new_scale](f1_s, f2_s)
+                gm_warp_or_cls, certainty, old_stuff = self.embedding_decoder(
+                    gp_posterior, f1_s, old_stuff, new_scale
+                )
+                
+                if self.embedding_decoder.is_classifier:
+                    flow = cls_to_flow_refine(
+                        gm_warp_or_cls,
+                    ).permute(0,3,1,2)
+                    corresps[ins].update({"gm_cls": gm_warp_or_cls,"gm_certainty": certainty,}) if self.training else None
+                else:
+                    corresps[ins].update({"gm_flow": gm_warp_or_cls,"gm_certainty": certainty,}) if self.training else None
+                    flow = gm_warp_or_cls.detach()
+                    
+            if new_scale in self.conv_refiner:
+                corresps[ins].update({"flow_pre_delta": flow}) if self.training else None
+                delta_flow, delta_certainty = self.conv_refiner[new_scale](
+                    f1_s, f2_s, flow, scale_factor = scale_factor, logits = certainty,
+                )                    
+                corresps[ins].update({"delta_flow": delta_flow,}) if self.training else None
+                displacement = ins*torch.stack((delta_flow[:, 0].float() / (self.refine_init * w),
+                                                delta_flow[:, 1].float() / (self.refine_init * h),),dim=1,)
+                flow = flow + displacement
+                certainty = (
+                    certainty + delta_certainty
+                )  # predict both certainty and displacement
+            corresps[ins].update({
+                "certainty": certainty,
+                "flow": flow,             
+            })
+            if new_scale != "1":
+                flow = F.interpolate(
+                    flow,
+                    size=sizes[ins // 2],
+                    mode=self.flow_upsample_mode,
+                )
+                certainty = F.interpolate(
+                    certainty,
+                    size=sizes[ins // 2],
+                    mode=self.flow_upsample_mode,
+                )
+                if self.detach:
+                    flow = flow.detach()
+                    certainty = certainty.detach()
+            #torch.cuda.empty_cache()                
+        return corresps
+
+
+class RegressionMatcher(nn.Module):
+    def __init__(
+        self,
+        encoder,
+        decoder,
+        h=448,
+        w=448,
+        sample_mode = "threshold",
+        upsample_preds = False,
+        symmetric = False,
+        name = None,
+        attenuate_cert = None,
+    ):
+        super().__init__()
+        self.attenuate_cert = attenuate_cert
+        self.encoder = encoder
+        self.decoder = decoder
+        self.name = name
+        self.w_resized = w
+        self.h_resized = h
+        self.og_transforms = get_tuple_transform_ops(resize=None, normalize=True)
+        self.sample_mode = sample_mode
+        self.upsample_preds = upsample_preds
+        self.upsample_res = (14*16*6, 14*16*6)
+        self.symmetric = symmetric
+        self.sample_thresh = 0.05
+            
+    def get_output_resolution(self):
+        if not self.upsample_preds:
+            return self.h_resized, self.w_resized
+        else:
+            return self.upsample_res
+    
+    def extract_backbone_features(self, batch, batched = True, upsample = False):
+        x_q = batch["im_A"]
+        x_s = batch["im_B"]
+        if batched:
+            X = torch.cat((x_q, x_s), dim = 0)
+            feature_pyramid = self.encoder(X, upsample = upsample)
+        else:
+            feature_pyramid = self.encoder(x_q, upsample = upsample), self.encoder(x_s, upsample = upsample)
+        return feature_pyramid
+
+    def sample(
+        self,
+        matches,
+        certainty,
+        num=10000,
+    ):
+        if "threshold" in self.sample_mode:
+            upper_thresh = self.sample_thresh
+            certainty = certainty.clone()
+            certainty[certainty > upper_thresh] = 1
+        matches, certainty = (
+            matches.reshape(-1, 4),
+            certainty.reshape(-1),
+        )
+        expansion_factor = 4 if "balanced" in self.sample_mode else 1
+        good_samples = torch.multinomial(certainty, 
+                          num_samples = min(expansion_factor*num, len(certainty)), 
+                          replacement=False)
+        good_matches, good_certainty = matches[good_samples], certainty[good_samples]
+        if "balanced" not in self.sample_mode:
+            return good_matches, good_certainty
+        density = kde(good_matches, std=0.1)
+        p = 1 / (density+1)
+        p[density < 10] = 1e-7 # Basically should have at least 10 perfect neighbours, or around 100 ok ones
+        balanced_samples = torch.multinomial(p, 
+                          num_samples = min(num,len(good_certainty)), 
+                          replacement=False)
+        return good_matches[balanced_samples], good_certainty[balanced_samples]
+
+    def forward(self, batch, batched = True, upsample = False, scale_factor = 1):
+        feature_pyramid = self.extract_backbone_features(batch, batched=batched, upsample = upsample)
+        if batched:
+            f_q_pyramid = {
+                scale: f_scale.chunk(2)[0] for scale, f_scale in feature_pyramid.items()
+            }
+            f_s_pyramid = {
+                scale: f_scale.chunk(2)[1] for scale, f_scale in feature_pyramid.items()
+            }
+        else:
+            f_q_pyramid, f_s_pyramid = feature_pyramid
+        corresps = self.decoder(f_q_pyramid, 
+                                f_s_pyramid, 
+                                upsample = upsample, 
+                                **(batch["corresps"] if "corresps" in batch else {}),
+                                scale_factor=scale_factor)
+        
+        return corresps
+
+    def forward_symmetric(self, batch, batched = True, upsample = False, scale_factor = 1):
+        feature_pyramid = self.extract_backbone_features(batch, batched = batched, upsample = upsample)
+        f_q_pyramid = feature_pyramid
+        f_s_pyramid = {
+            scale: torch.cat((f_scale.chunk(2)[1], f_scale.chunk(2)[0]), dim = 0)
+            for scale, f_scale in feature_pyramid.items()
+        }
+        corresps = self.decoder(f_q_pyramid, 
+                                f_s_pyramid, 
+                                upsample = upsample, 
+                                **(batch["corresps"] if "corresps" in batch else {}),
+                                scale_factor=scale_factor)
+        return corresps
+    
+    def to_pixel_coordinates(self, matches, H_A, W_A, H_B, W_B):
+        kpts_A, kpts_B = matches[...,:2], matches[...,2:]
+        kpts_A = torch.stack((W_A/2 * (kpts_A[...,0]+1), H_A/2 * (kpts_A[...,1]+1)),axis=-1)
+        kpts_B = torch.stack((W_B/2 * (kpts_B[...,0]+1), H_B/2 * (kpts_B[...,1]+1)),axis=-1)
+        return kpts_A, kpts_B
+
+    def match(
+        self,
+        im_A_path,
+        im_B_path,
+        *args,
+        batched=False,
+        device = None,
+    ):
+        if device is None:
+            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        from PIL import Image
+        if isinstance(im_A_path, (str, os.PathLike)):
+            im_A, im_B = Image.open(im_A_path), Image.open(im_B_path)
+        else:
+            # Assume its not a path
+            im_A, im_B = im_A_path, im_B_path
+        symmetric = self.symmetric
+        self.train(False)
+        with torch.no_grad():
+            if not batched:
+                b = 1
+                w, h = im_A.size
+                w2, h2 = im_B.size
+                # Get images in good format
+                ws = self.w_resized
+                hs = self.h_resized
+                
+                test_transform = get_tuple_transform_ops(
+                    resize=(hs, ws), normalize=True, clahe = False
+                )
+                im_A, im_B = test_transform((im_A, im_B))
+                batch = {"im_A": im_A[None].to(device), "im_B": im_B[None].to(device)}
+            else:
+                b, c, h, w = im_A.shape
+                b, c, h2, w2 = im_B.shape
+                assert w == w2 and h == h2, "For batched images we assume same size"
+                batch = {"im_A": im_A.to(device), "im_B": im_B.to(device)}
+                if h != self.h_resized or self.w_resized != w:
+                    warn("Model resolution and batch resolution differ, may produce unexpected results")
+                hs, ws = h, w
+            finest_scale = 1
+            # Run matcher
+            if symmetric:
+                corresps  = self.forward_symmetric(batch)
+            else:
+                corresps = self.forward(batch, batched = True)
+
+            if self.upsample_preds:
+                hs, ws = self.upsample_res
+            
+            if self.attenuate_cert:
+                low_res_certainty = F.interpolate(
+                corresps[16]["certainty"], size=(hs, ws), align_corners=False, mode="bilinear"
+                )
+                cert_clamp = 0
+                factor = 0.5
+                low_res_certainty = factor*low_res_certainty*(low_res_certainty < cert_clamp)
+
+            if self.upsample_preds:
+                finest_corresps = corresps[finest_scale]
+                torch.cuda.empty_cache()
+                test_transform = get_tuple_transform_ops(
+                    resize=(hs, ws), normalize=True
+                )
+                im_A, im_B = Image.open(im_A_path), Image.open(im_B_path)
+                im_A, im_B = test_transform((im_A, im_B))
+                im_A, im_B = im_A[None].to(device), im_B[None].to(device)
+                scale_factor = math.sqrt(self.upsample_res[0] * self.upsample_res[1] / (self.w_resized * self.h_resized))
+                batch = {"im_A": im_A, "im_B": im_B, "corresps": finest_corresps}
+                if symmetric:
+                    corresps = self.forward_symmetric(batch, upsample = True, batched=True, scale_factor = scale_factor)
+                else:
+                    corresps = self.forward(batch, batched = True, upsample=True, scale_factor = scale_factor)
+            
+            im_A_to_im_B = corresps[finest_scale]["flow"] 
+            certainty = corresps[finest_scale]["certainty"] - (low_res_certainty if self.attenuate_cert else 0)
+            if finest_scale != 1:
+                im_A_to_im_B = F.interpolate(
+                im_A_to_im_B, size=(hs, ws), align_corners=False, mode="bilinear"
+                )
+                certainty = F.interpolate(
+                certainty, size=(hs, ws), align_corners=False, mode="bilinear"
+                )
+            im_A_to_im_B = im_A_to_im_B.permute(
+                0, 2, 3, 1
+                )
+            # Create im_A meshgrid
+            im_A_coords = torch.meshgrid(
+                (
+                    torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device="cuda"),
+                    torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device="cuda"),
+                )
+            )
+            im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
+            im_A_coords = im_A_coords[None].expand(b, 2, hs, ws)
+            certainty = certainty.sigmoid()  # logits -> probs
+            im_A_coords = im_A_coords.permute(0, 2, 3, 1)
+            if (im_A_to_im_B.abs() > 1).any() and True:
+                wrong = (im_A_to_im_B.abs() > 1).sum(dim=-1) > 0
+                certainty[wrong[:,None]] = 0
+            im_A_to_im_B = torch.clamp(im_A_to_im_B, -1, 1)
+            if symmetric:
+                A_to_B, B_to_A = im_A_to_im_B.chunk(2)
+                q_warp = torch.cat((im_A_coords, A_to_B), dim=-1)
+                im_B_coords = im_A_coords
+                s_warp = torch.cat((B_to_A, im_B_coords), dim=-1)
+                warp = torch.cat((q_warp, s_warp),dim=2)
+                certainty = torch.cat(certainty.chunk(2), dim=3)
+            else:
+                warp = torch.cat((im_A_coords, im_A_to_im_B), dim=-1)
+            if batched:
+                return (
+                    warp,
+                    certainty[:, 0]
+                )
+            else:
+                return (
+                    warp[0],
+                    certainty[0, 0],
+                )
+

+ 30 - 0
roma/models/model_zoo/__init__.py

@@ -0,0 +1,30 @@
+import torch
+from .roma_models import roma_model
+
+weight_urls = {
+    "roma": {
+        "outdoor": "https://github.com/Parskatt/storage/releases/download/roma/roma_outdoor.pth",
+        "indoor": "https://github.com/Parskatt/storage/releases/download/roma/roma_indoor.pth",
+    },
+    "dinov2": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth", #hopefully this doesnt change :D
+}
+
+def roma_outdoor(device, weights=None, dinov2_weights=None):
+    if weights is None:
+        weights = torch.hub.load_state_dict_from_url(weight_urls["roma"]["outdoor"],
+                                                     map_location=device)
+    if dinov2_weights is None:
+        dinov2_weights = torch.hub.load_state_dict_from_url(weight_urls["dinov2"],
+                                                     map_location=device)
+    return roma_model(resolution=(14*8*6,14*8*6), upsample_preds=True,
+               weights=weights,dinov2_weights = dinov2_weights,device=device)
+
+def roma_indoor(device, weights=None, dinov2_weights=None):
+    if weights is None:
+        weights = torch.hub.load_state_dict_from_url(weight_urls["roma"]["indoor"],
+                                                     map_location=device)
+    if dinov2_weights is None:
+        dinov2_weights = torch.hub.load_state_dict_from_url(weight_urls["dinov2"],
+                                                     map_location=device)
+    return roma_model(resolution=(14*8*5,14*8*5), upsample_preds=False,
+               weights=weights,dinov2_weights = dinov2_weights,device=device)

+ 157 - 0
roma/models/model_zoo/roma_models.py

@@ -0,0 +1,157 @@
+import warnings
+import torch.nn as nn
+from roma.models.matcher import *
+from roma.models.transformer import Block, TransformerDecoder, MemEffAttention
+from roma.models.encoders import *
+
+def roma_model(resolution, upsample_preds, device = None, weights=None, dinov2_weights=None, **kwargs):
+    # roma weights and dinov2 weights are loaded seperately, as dinov2 weights are not parameters
+    torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
+    torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
+    warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
+    gp_dim = 512
+    feat_dim = 512
+    decoder_dim = gp_dim + feat_dim
+    cls_to_coord_res = 64
+    coordinate_decoder = TransformerDecoder(
+        nn.Sequential(*[Block(decoder_dim, 8, attn_class=MemEffAttention) for _ in range(5)]), 
+        decoder_dim, 
+        cls_to_coord_res**2 + 1,
+        is_classifier=True,
+        amp = True,
+        pos_enc = False,)
+    dw = True
+    hidden_blocks = 8
+    kernel_size = 5
+    displacement_emb = "linear"
+    disable_local_corr_grad = True
+    
+    conv_refiner = nn.ModuleDict(
+        {
+            "16": ConvRefiner(
+                2 * 512+128+(2*7+1)**2,
+                2 * 512+128+(2*7+1)**2,
+                2 + 1,
+                kernel_size=kernel_size,
+                dw=dw,
+                hidden_blocks=hidden_blocks,
+                displacement_emb=displacement_emb,
+                displacement_emb_dim=128,
+                local_corr_radius = 7,
+                corr_in_other = True,
+                amp = True,
+                disable_local_corr_grad = disable_local_corr_grad,
+                bn_momentum = 0.01,
+            ),
+            "8": ConvRefiner(
+                2 * 512+64+(2*3+1)**2,
+                2 * 512+64+(2*3+1)**2,
+                2 + 1,
+                kernel_size=kernel_size,
+                dw=dw,
+                hidden_blocks=hidden_blocks,
+                displacement_emb=displacement_emb,
+                displacement_emb_dim=64,
+                local_corr_radius = 3,
+                corr_in_other = True,
+                amp = True,
+                disable_local_corr_grad = disable_local_corr_grad,
+                bn_momentum = 0.01,
+            ),
+            "4": ConvRefiner(
+                2 * 256+32+(2*2+1)**2,
+                2 * 256+32+(2*2+1)**2,
+                2 + 1,
+                kernel_size=kernel_size,
+                dw=dw,
+                hidden_blocks=hidden_blocks,
+                displacement_emb=displacement_emb,
+                displacement_emb_dim=32,
+                local_corr_radius = 2,
+                corr_in_other = True,
+                amp = True,
+                disable_local_corr_grad = disable_local_corr_grad,
+                bn_momentum = 0.01,
+            ),
+            "2": ConvRefiner(
+                2 * 64+16,
+                128+16,
+                2 + 1,
+                kernel_size=kernel_size,
+                dw=dw,
+                hidden_blocks=hidden_blocks,
+                displacement_emb=displacement_emb,
+                displacement_emb_dim=16,
+                amp = True,
+                disable_local_corr_grad = disable_local_corr_grad,
+                bn_momentum = 0.01,
+            ),
+            "1": ConvRefiner(
+                2 * 9 + 6,
+                24,
+                2 + 1,
+                kernel_size=kernel_size,
+                dw=dw,
+                hidden_blocks = hidden_blocks,
+                displacement_emb = displacement_emb,
+                displacement_emb_dim = 6,
+                amp = True,
+                disable_local_corr_grad = disable_local_corr_grad,
+                bn_momentum = 0.01,
+            ),
+        }
+    )
+    kernel_temperature = 0.2
+    learn_temperature = False
+    no_cov = True
+    kernel = CosKernel
+    only_attention = False
+    basis = "fourier"
+    gp16 = GP(
+        kernel,
+        T=kernel_temperature,
+        learn_temperature=learn_temperature,
+        only_attention=only_attention,
+        gp_dim=gp_dim,
+        basis=basis,
+        no_cov=no_cov,
+    )
+    gps = nn.ModuleDict({"16": gp16})
+    proj16 = nn.Sequential(nn.Conv2d(1024, 512, 1, 1), nn.BatchNorm2d(512))
+    proj8 = nn.Sequential(nn.Conv2d(512, 512, 1, 1), nn.BatchNorm2d(512))
+    proj4 = nn.Sequential(nn.Conv2d(256, 256, 1, 1), nn.BatchNorm2d(256))
+    proj2 = nn.Sequential(nn.Conv2d(128, 64, 1, 1), nn.BatchNorm2d(64))
+    proj1 = nn.Sequential(nn.Conv2d(64, 9, 1, 1), nn.BatchNorm2d(9))
+    proj = nn.ModuleDict({
+        "16": proj16,
+        "8": proj8,
+        "4": proj4,
+        "2": proj2,
+        "1": proj1,
+        })
+    displacement_dropout_p = 0.0
+    gm_warp_dropout_p = 0.0
+    decoder = Decoder(coordinate_decoder, 
+                      gps, 
+                      proj, 
+                      conv_refiner, 
+                      detach=True, 
+                      scales=["16", "8", "4", "2", "1"], 
+                      displacement_dropout_p = displacement_dropout_p,
+                      gm_warp_dropout_p = gm_warp_dropout_p)
+    
+    encoder = CNNandDinov2(
+        cnn_kwargs = dict(
+            pretrained=False,
+            amp = True),
+        amp = True,
+        use_vgg = True,
+        dinov2_weights = dinov2_weights
+    )
+    h,w = resolution
+    symmetric = True
+    attenuate_cert = True
+    matcher = RegressionMatcher(encoder, decoder, h=h, w=w, upsample_preds=upsample_preds, 
+                                symmetric = symmetric, attenuate_cert=attenuate_cert, **kwargs).to(device)
+    matcher.load_state_dict(weights)
+    return matcher

+ 47 - 0
roma/models/transformer/__init__.py

@@ -0,0 +1,47 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from roma.utils.utils import get_grid
+from .layers.block import Block
+from .layers.attention import MemEffAttention
+from .dinov2 import vit_large
+
+class TransformerDecoder(nn.Module):
+    def __init__(self, blocks, hidden_dim, out_dim, is_classifier = False, *args, 
+                 amp = False, pos_enc = True, learned_embeddings = False, embedding_dim = None, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        self.blocks = blocks
+        self.to_out = nn.Linear(hidden_dim, out_dim)
+        self.hidden_dim = hidden_dim
+        self.out_dim = out_dim
+        self._scales = [16]
+        self.is_classifier = is_classifier
+        self.amp = amp
+        self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+        self.pos_enc = pos_enc
+        self.learned_embeddings = learned_embeddings
+        if self.learned_embeddings:
+            self.learned_pos_embeddings = nn.Parameter(nn.init.kaiming_normal_(torch.empty((1, hidden_dim, embedding_dim, embedding_dim))))
+
+    def scales(self):
+        return self._scales.copy()
+
+    def forward(self, gp_posterior, features, old_stuff, new_scale):
+        with torch.autocast("cuda", dtype=self.amp_dtype, enabled=self.amp):
+            B,C,H,W = gp_posterior.shape
+            x = torch.cat((gp_posterior, features), dim = 1)
+            B,C,H,W = x.shape
+            grid = get_grid(B, H, W, x.device).reshape(B,H*W,2)
+            if self.learned_embeddings:
+                pos_enc = F.interpolate(self.learned_pos_embeddings, size = (H,W), mode = 'bilinear', align_corners = False).permute(0,2,3,1).reshape(1,H*W,C)
+            else:
+                pos_enc = 0
+            tokens = x.reshape(B,C,H*W).permute(0,2,1) + pos_enc
+            z = self.blocks(tokens)
+            out = self.to_out(z)
+            out = out.permute(0,2,1).reshape(B, self.out_dim, H, W)
+            warp, certainty = out[:, :-1], out[:, -1:]
+            return warp, certainty, None
+
+

+ 359 - 0
roma/models/transformer/dinov2.py

@@ -0,0 +1,359 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+#   https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
+#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+from functools import partial
+import math
+import logging
+from typing import Sequence, Tuple, Union, Callable
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+from torch.nn.init import trunc_normal_
+
+from .layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
+
+
+
+def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
+    if not depth_first and include_root:
+        fn(module=module, name=name)
+    for child_name, child_module in module.named_children():
+        child_name = ".".join((name, child_name)) if name else child_name
+        named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
+    if depth_first and include_root:
+        fn(module=module, name=name)
+    return module
+
+
+class BlockChunk(nn.ModuleList):
+    def forward(self, x):
+        for b in self:
+            x = b(x)
+        return x
+
+
+class DinoVisionTransformer(nn.Module):
+    def __init__(
+        self,
+        img_size=224,
+        patch_size=16,
+        in_chans=3,
+        embed_dim=768,
+        depth=12,
+        num_heads=12,
+        mlp_ratio=4.0,
+        qkv_bias=True,
+        ffn_bias=True,
+        proj_bias=True,
+        drop_path_rate=0.0,
+        drop_path_uniform=False,
+        init_values=None,  # for layerscale: None or 0 => no layerscale
+        embed_layer=PatchEmbed,
+        act_layer=nn.GELU,
+        block_fn=Block,
+        ffn_layer="mlp",
+        block_chunks=1,
+    ):
+        """
+        Args:
+            img_size (int, tuple): input image size
+            patch_size (int, tuple): patch size
+            in_chans (int): number of input channels
+            embed_dim (int): embedding dimension
+            depth (int): depth of transformer
+            num_heads (int): number of attention heads
+            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
+            qkv_bias (bool): enable bias for qkv if True
+            proj_bias (bool): enable bias for proj in attn if True
+            ffn_bias (bool): enable bias for ffn if True
+            drop_path_rate (float): stochastic depth rate
+            drop_path_uniform (bool): apply uniform drop rate across blocks
+            weight_init (str): weight init scheme
+            init_values (float): layer-scale init values
+            embed_layer (nn.Module): patch embedding layer
+            act_layer (nn.Module): MLP activation layer
+            block_fn (nn.Module): transformer block class
+            ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
+            block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
+        """
+        super().__init__()
+        norm_layer = partial(nn.LayerNorm, eps=1e-6)
+
+        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
+        self.num_tokens = 1
+        self.n_blocks = depth
+        self.num_heads = num_heads
+        self.patch_size = patch_size
+
+        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+        num_patches = self.patch_embed.num_patches
+
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
+
+        if drop_path_uniform is True:
+            dpr = [drop_path_rate] * depth
+        else:
+            dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
+
+        if ffn_layer == "mlp":
+            ffn_layer = Mlp
+        elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
+            ffn_layer = SwiGLUFFNFused
+        elif ffn_layer == "identity":
+
+            def f(*args, **kwargs):
+                return nn.Identity()
+
+            ffn_layer = f
+        else:
+            raise NotImplementedError
+
+        blocks_list = [
+            block_fn(
+                dim=embed_dim,
+                num_heads=num_heads,
+                mlp_ratio=mlp_ratio,
+                qkv_bias=qkv_bias,
+                proj_bias=proj_bias,
+                ffn_bias=ffn_bias,
+                drop_path=dpr[i],
+                norm_layer=norm_layer,
+                act_layer=act_layer,
+                ffn_layer=ffn_layer,
+                init_values=init_values,
+            )
+            for i in range(depth)
+        ]
+        if block_chunks > 0:
+            self.chunked_blocks = True
+            chunked_blocks = []
+            chunksize = depth // block_chunks
+            for i in range(0, depth, chunksize):
+                # this is to keep the block index consistent if we chunk the block list
+                chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
+            self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
+        else:
+            self.chunked_blocks = False
+            self.blocks = nn.ModuleList(blocks_list)
+
+        self.norm = norm_layer(embed_dim)
+        self.head = nn.Identity()
+
+        self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
+
+        self.init_weights()
+        for param in self.parameters():
+            param.requires_grad = False
+    
+    @property
+    def device(self):
+        return self.cls_token.device
+
+    def init_weights(self):
+        trunc_normal_(self.pos_embed, std=0.02)
+        nn.init.normal_(self.cls_token, std=1e-6)
+        named_apply(init_weights_vit_timm, self)
+
+    def interpolate_pos_encoding(self, x, w, h):
+        previous_dtype = x.dtype
+        npatch = x.shape[1] - 1
+        N = self.pos_embed.shape[1] - 1
+        if npatch == N and w == h:
+            return self.pos_embed
+        pos_embed = self.pos_embed.float()
+        class_pos_embed = pos_embed[:, 0]
+        patch_pos_embed = pos_embed[:, 1:]
+        dim = x.shape[-1]
+        w0 = w // self.patch_size
+        h0 = h // self.patch_size
+        # we add a small number to avoid floating point error in the interpolation
+        # see discussion at https://github.com/facebookresearch/dino/issues/8
+        w0, h0 = w0 + 0.1, h0 + 0.1
+
+        patch_pos_embed = nn.functional.interpolate(
+            patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
+            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
+            mode="bicubic",
+        )
+
+        assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
+        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
+
+    def prepare_tokens_with_masks(self, x, masks=None):
+        B, nc, w, h = x.shape
+        x = self.patch_embed(x)
+        if masks is not None:
+            x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
+
+        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+        x = x + self.interpolate_pos_encoding(x, w, h)
+
+        return x
+
+    def forward_features_list(self, x_list, masks_list):
+        x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
+        for blk in self.blocks:
+            x = blk(x)
+
+        all_x = x
+        output = []
+        for x, masks in zip(all_x, masks_list):
+            x_norm = self.norm(x)
+            output.append(
+                {
+                    "x_norm_clstoken": x_norm[:, 0],
+                    "x_norm_patchtokens": x_norm[:, 1:],
+                    "x_prenorm": x,
+                    "masks": masks,
+                }
+            )
+        return output
+
+    def forward_features(self, x, masks=None):
+        if isinstance(x, list):
+            return self.forward_features_list(x, masks)
+
+        x = self.prepare_tokens_with_masks(x, masks)
+
+        for blk in self.blocks:
+            x = blk(x)
+
+        x_norm = self.norm(x)
+        return {
+            "x_norm_clstoken": x_norm[:, 0],
+            "x_norm_patchtokens": x_norm[:, 1:],
+            "x_prenorm": x,
+            "masks": masks,
+        }
+
+    def _get_intermediate_layers_not_chunked(self, x, n=1):
+        x = self.prepare_tokens_with_masks(x)
+        # If n is an int, take the n last blocks. If it's a list, take them
+        output, total_block_len = [], len(self.blocks)
+        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+        for i, blk in enumerate(self.blocks):
+            x = blk(x)
+            if i in blocks_to_take:
+                output.append(x)
+        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+        return output
+
+    def _get_intermediate_layers_chunked(self, x, n=1):
+        x = self.prepare_tokens_with_masks(x)
+        output, i, total_block_len = [], 0, len(self.blocks[-1])
+        # If n is an int, take the n last blocks. If it's a list, take them
+        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+        for block_chunk in self.blocks:
+            for blk in block_chunk[i:]:  # Passing the nn.Identity()
+                x = blk(x)
+                if i in blocks_to_take:
+                    output.append(x)
+                i += 1
+        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+        return output
+
+    def get_intermediate_layers(
+        self,
+        x: torch.Tensor,
+        n: Union[int, Sequence] = 1,  # Layers or n last layers to take
+        reshape: bool = False,
+        return_class_token: bool = False,
+        norm=True,
+    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
+        if self.chunked_blocks:
+            outputs = self._get_intermediate_layers_chunked(x, n)
+        else:
+            outputs = self._get_intermediate_layers_not_chunked(x, n)
+        if norm:
+            outputs = [self.norm(out) for out in outputs]
+        class_tokens = [out[:, 0] for out in outputs]
+        outputs = [out[:, 1:] for out in outputs]
+        if reshape:
+            B, _, w, h = x.shape
+            outputs = [
+                out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
+                for out in outputs
+            ]
+        if return_class_token:
+            return tuple(zip(outputs, class_tokens))
+        return tuple(outputs)
+
+    def forward(self, *args, is_training=False, **kwargs):
+        ret = self.forward_features(*args, **kwargs)
+        if is_training:
+            return ret
+        else:
+            return self.head(ret["x_norm_clstoken"])
+
+
+def init_weights_vit_timm(module: nn.Module, name: str = ""):
+    """ViT weight initialization, original timm impl (for reproducibility)"""
+    if isinstance(module, nn.Linear):
+        trunc_normal_(module.weight, std=0.02)
+        if module.bias is not None:
+            nn.init.zeros_(module.bias)
+
+
+def vit_small(patch_size=16, **kwargs):
+    model = DinoVisionTransformer(
+        patch_size=patch_size,
+        embed_dim=384,
+        depth=12,
+        num_heads=6,
+        mlp_ratio=4,
+        block_fn=partial(Block, attn_class=MemEffAttention),
+        **kwargs,
+    )
+    return model
+
+
+def vit_base(patch_size=16, **kwargs):
+    model = DinoVisionTransformer(
+        patch_size=patch_size,
+        embed_dim=768,
+        depth=12,
+        num_heads=12,
+        mlp_ratio=4,
+        block_fn=partial(Block, attn_class=MemEffAttention),
+        **kwargs,
+    )
+    return model
+
+
+def vit_large(patch_size=16, **kwargs):
+    model = DinoVisionTransformer(
+        patch_size=patch_size,
+        embed_dim=1024,
+        depth=24,
+        num_heads=16,
+        mlp_ratio=4,
+        block_fn=partial(Block, attn_class=MemEffAttention),
+        **kwargs,
+    )
+    return model
+
+
+def vit_giant2(patch_size=16, **kwargs):
+    """
+    Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
+    """
+    model = DinoVisionTransformer(
+        patch_size=patch_size,
+        embed_dim=1536,
+        depth=40,
+        num_heads=24,
+        mlp_ratio=4,
+        block_fn=partial(Block, attn_class=MemEffAttention),
+        **kwargs,
+    )
+    return model

+ 12 - 0
roma/models/transformer/layers/__init__.py

@@ -0,0 +1,12 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .dino_head import DINOHead
+from .mlp import Mlp
+from .patch_embed import PatchEmbed
+from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
+from .block import NestedTensorBlock
+from .attention import MemEffAttention

+ 81 - 0
roma/models/transformer/layers/attention.py

@@ -0,0 +1,81 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+import logging
+
+from torch import Tensor
+from torch import nn
+
+
+logger = logging.getLogger("dinov2")
+
+
+try:
+    from xformers.ops import memory_efficient_attention, unbind, fmha
+
+    XFORMERS_AVAILABLE = True
+except ImportError:
+    logger.warning("xFormers not available")
+    XFORMERS_AVAILABLE = False
+
+
+class Attention(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int = 8,
+        qkv_bias: bool = False,
+        proj_bias: bool = True,
+        attn_drop: float = 0.0,
+        proj_drop: float = 0.0,
+    ) -> None:
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.scale = head_dim**-0.5
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim, bias=proj_bias)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+    def forward(self, x: Tensor) -> Tensor:
+        B, N, C = x.shape
+        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+
+        q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
+        attn = q @ k.transpose(-2, -1)
+
+        attn = attn.softmax(dim=-1)
+        attn = self.attn_drop(attn)
+
+        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+
+class MemEffAttention(Attention):
+    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
+        if not XFORMERS_AVAILABLE:
+            assert attn_bias is None, "xFormers is required for nested tensors usage"
+            return super().forward(x)
+
+        B, N, C = x.shape
+        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
+
+        q, k, v = unbind(qkv, 2)
+
+        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
+        x = x.reshape([B, N, C])
+
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x

+ 252 - 0
roma/models/transformer/layers/block.py

@@ -0,0 +1,252 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+import logging
+from typing import Callable, List, Any, Tuple, Dict
+
+import torch
+from torch import nn, Tensor
+
+from .attention import Attention, MemEffAttention
+from .drop_path import DropPath
+from .layer_scale import LayerScale
+from .mlp import Mlp
+
+
+logger = logging.getLogger("dinov2")
+
+
+try:
+    from xformers.ops import fmha
+    from xformers.ops import scaled_index_add, index_select_cat
+
+    XFORMERS_AVAILABLE = True
+except ImportError:
+    logger.warning("xFormers not available")
+    XFORMERS_AVAILABLE = False
+
+
+class Block(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int,
+        mlp_ratio: float = 4.0,
+        qkv_bias: bool = False,
+        proj_bias: bool = True,
+        ffn_bias: bool = True,
+        drop: float = 0.0,
+        attn_drop: float = 0.0,
+        init_values=None,
+        drop_path: float = 0.0,
+        act_layer: Callable[..., nn.Module] = nn.GELU,
+        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
+        attn_class: Callable[..., nn.Module] = Attention,
+        ffn_layer: Callable[..., nn.Module] = Mlp,
+    ) -> None:
+        super().__init__()
+        # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
+        self.norm1 = norm_layer(dim)
+        self.attn = attn_class(
+            dim,
+            num_heads=num_heads,
+            qkv_bias=qkv_bias,
+            proj_bias=proj_bias,
+            attn_drop=attn_drop,
+            proj_drop=drop,
+        )
+        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = ffn_layer(
+            in_features=dim,
+            hidden_features=mlp_hidden_dim,
+            act_layer=act_layer,
+            drop=drop,
+            bias=ffn_bias,
+        )
+        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+        self.sample_drop_ratio = drop_path
+
+    def forward(self, x: Tensor) -> Tensor:
+        def attn_residual_func(x: Tensor) -> Tensor:
+            return self.ls1(self.attn(self.norm1(x)))
+
+        def ffn_residual_func(x: Tensor) -> Tensor:
+            return self.ls2(self.mlp(self.norm2(x)))
+
+        if self.training and self.sample_drop_ratio > 0.1:
+            # the overhead is compensated only for a drop path rate larger than 0.1
+            x = drop_add_residual_stochastic_depth(
+                x,
+                residual_func=attn_residual_func,
+                sample_drop_ratio=self.sample_drop_ratio,
+            )
+            x = drop_add_residual_stochastic_depth(
+                x,
+                residual_func=ffn_residual_func,
+                sample_drop_ratio=self.sample_drop_ratio,
+            )
+        elif self.training and self.sample_drop_ratio > 0.0:
+            x = x + self.drop_path1(attn_residual_func(x))
+            x = x + self.drop_path1(ffn_residual_func(x))  # FIXME: drop_path2
+        else:
+            x = x + attn_residual_func(x)
+            x = x + ffn_residual_func(x)
+        return x
+
+
+def drop_add_residual_stochastic_depth(
+    x: Tensor,
+    residual_func: Callable[[Tensor], Tensor],
+    sample_drop_ratio: float = 0.0,
+) -> Tensor:
+    # 1) extract subset using permutation
+    b, n, d = x.shape
+    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+    x_subset = x[brange]
+
+    # 2) apply residual_func to get residual
+    residual = residual_func(x_subset)
+
+    x_flat = x.flatten(1)
+    residual = residual.flatten(1)
+
+    residual_scale_factor = b / sample_subset_size
+
+    # 3) add the residual
+    x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+    return x_plus_residual.view_as(x)
+
+
+def get_branges_scales(x, sample_drop_ratio=0.0):
+    b, n, d = x.shape
+    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+    residual_scale_factor = b / sample_subset_size
+    return brange, residual_scale_factor
+
+
+def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
+    if scaling_vector is None:
+        x_flat = x.flatten(1)
+        residual = residual.flatten(1)
+        x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+    else:
+        x_plus_residual = scaled_index_add(
+            x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
+        )
+    return x_plus_residual
+
+
+attn_bias_cache: Dict[Tuple, Any] = {}
+
+
+def get_attn_bias_and_cat(x_list, branges=None):
+    """
+    this will perform the index select, cat the tensors, and provide the attn_bias from cache
+    """
+    batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
+    all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
+    if all_shapes not in attn_bias_cache.keys():
+        seqlens = []
+        for b, x in zip(batch_sizes, x_list):
+            for _ in range(b):
+                seqlens.append(x.shape[1])
+        attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
+        attn_bias._batch_sizes = batch_sizes
+        attn_bias_cache[all_shapes] = attn_bias
+
+    if branges is not None:
+        cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
+    else:
+        tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
+        cat_tensors = torch.cat(tensors_bs1, dim=1)
+
+    return attn_bias_cache[all_shapes], cat_tensors
+
+
+def drop_add_residual_stochastic_depth_list(
+    x_list: List[Tensor],
+    residual_func: Callable[[Tensor, Any], Tensor],
+    sample_drop_ratio: float = 0.0,
+    scaling_vector=None,
+) -> Tensor:
+    # 1) generate random set of indices for dropping samples in the batch
+    branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
+    branges = [s[0] for s in branges_scales]
+    residual_scale_factors = [s[1] for s in branges_scales]
+
+    # 2) get attention bias and index+concat the tensors
+    attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
+
+    # 3) apply residual_func to get residual, and split the result
+    residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore
+
+    outputs = []
+    for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
+        outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
+    return outputs
+
+
+class NestedTensorBlock(Block):
+    def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
+        """
+        x_list contains a list of tensors to nest together and run
+        """
+        assert isinstance(self.attn, MemEffAttention)
+
+        if self.training and self.sample_drop_ratio > 0.0:
+
+            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+                return self.attn(self.norm1(x), attn_bias=attn_bias)
+
+            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+                return self.mlp(self.norm2(x))
+
+            x_list = drop_add_residual_stochastic_depth_list(
+                x_list,
+                residual_func=attn_residual_func,
+                sample_drop_ratio=self.sample_drop_ratio,
+                scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
+            )
+            x_list = drop_add_residual_stochastic_depth_list(
+                x_list,
+                residual_func=ffn_residual_func,
+                sample_drop_ratio=self.sample_drop_ratio,
+                scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
+            )
+            return x_list
+        else:
+
+            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+                return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
+
+            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+                return self.ls2(self.mlp(self.norm2(x)))
+
+            attn_bias, x = get_attn_bias_and_cat(x_list)
+            x = x + attn_residual_func(x, attn_bias=attn_bias)
+            x = x + ffn_residual_func(x)
+            return attn_bias.split(x)
+
+    def forward(self, x_or_x_list):
+        if isinstance(x_or_x_list, Tensor):
+            return super().forward(x_or_x_list)
+        elif isinstance(x_or_x_list, list):
+            assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
+            return self.forward_nested(x_or_x_list)
+        else:
+            raise AssertionError

+ 59 - 0
roma/models/transformer/layers/dino_head.py

@@ -0,0 +1,59 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+from torch.nn.init import trunc_normal_
+from torch.nn.utils import weight_norm
+
+
+class DINOHead(nn.Module):
+    def __init__(
+        self,
+        in_dim,
+        out_dim,
+        use_bn=False,
+        nlayers=3,
+        hidden_dim=2048,
+        bottleneck_dim=256,
+        mlp_bias=True,
+    ):
+        super().__init__()
+        nlayers = max(nlayers, 1)
+        self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
+        self.apply(self._init_weights)
+        self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
+        self.last_layer.weight_g.data.fill_(1)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            trunc_normal_(m.weight, std=0.02)
+            if isinstance(m, nn.Linear) and m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+
+    def forward(self, x):
+        x = self.mlp(x)
+        eps = 1e-6 if x.dtype == torch.float16 else 1e-12
+        x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
+        x = self.last_layer(x)
+        return x
+
+
+def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
+    if nlayers == 1:
+        return nn.Linear(in_dim, bottleneck_dim, bias=bias)
+    else:
+        layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
+        if use_bn:
+            layers.append(nn.BatchNorm1d(hidden_dim))
+        layers.append(nn.GELU())
+        for _ in range(nlayers - 2):
+            layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
+            if use_bn:
+                layers.append(nn.BatchNorm1d(hidden_dim))
+            layers.append(nn.GELU())
+        layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
+        return nn.Sequential(*layers)

+ 35 - 0
roma/models/transformer/layers/drop_path.py

@@ -0,0 +1,35 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
+
+
+from torch import nn
+
+
+def drop_path(x, drop_prob: float = 0.0, training: bool = False):
+    if drop_prob == 0.0 or not training:
+        return x
+    keep_prob = 1 - drop_prob
+    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
+    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+    if keep_prob > 0.0:
+        random_tensor.div_(keep_prob)
+    output = x * random_tensor
+    return output
+
+
+class DropPath(nn.Module):
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+    def __init__(self, drop_prob=None):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+
+    def forward(self, x):
+        return drop_path(x, self.drop_prob, self.training)

+ 28 - 0
roma/models/transformer/layers/layer_scale.py

@@ -0,0 +1,28 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
+
+from typing import Union
+
+import torch
+from torch import Tensor
+from torch import nn
+
+
+class LayerScale(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        init_values: Union[float, Tensor] = 1e-5,
+        inplace: bool = False,
+    ) -> None:
+        super().__init__()
+        self.inplace = inplace
+        self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+    def forward(self, x: Tensor) -> Tensor:
+        return x.mul_(self.gamma) if self.inplace else x * self.gamma

+ 41 - 0
roma/models/transformer/layers/mlp.py

@@ -0,0 +1,41 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
+
+
+from typing import Callable, Optional
+
+from torch import Tensor, nn
+
+
+class Mlp(nn.Module):
+    def __init__(
+        self,
+        in_features: int,
+        hidden_features: Optional[int] = None,
+        out_features: Optional[int] = None,
+        act_layer: Callable[..., nn.Module] = nn.GELU,
+        drop: float = 0.0,
+        bias: bool = True,
+    ) -> None:
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x: Tensor) -> Tensor:
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x

+ 89 - 0
roma/models/transformer/layers/patch_embed.py

@@ -0,0 +1,89 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+from typing import Callable, Optional, Tuple, Union
+
+from torch import Tensor
+import torch.nn as nn
+
+
+def make_2tuple(x):
+    if isinstance(x, tuple):
+        assert len(x) == 2
+        return x
+
+    assert isinstance(x, int)
+    return (x, x)
+
+
+class PatchEmbed(nn.Module):
+    """
+    2D image to patch embedding: (B,C,H,W) -> (B,N,D)
+
+    Args:
+        img_size: Image size.
+        patch_size: Patch token size.
+        in_chans: Number of input image channels.
+        embed_dim: Number of linear projection output channels.
+        norm_layer: Normalization layer.
+    """
+
+    def __init__(
+        self,
+        img_size: Union[int, Tuple[int, int]] = 224,
+        patch_size: Union[int, Tuple[int, int]] = 16,
+        in_chans: int = 3,
+        embed_dim: int = 768,
+        norm_layer: Optional[Callable] = None,
+        flatten_embedding: bool = True,
+    ) -> None:
+        super().__init__()
+
+        image_HW = make_2tuple(img_size)
+        patch_HW = make_2tuple(patch_size)
+        patch_grid_size = (
+            image_HW[0] // patch_HW[0],
+            image_HW[1] // patch_HW[1],
+        )
+
+        self.img_size = image_HW
+        self.patch_size = patch_HW
+        self.patches_resolution = patch_grid_size
+        self.num_patches = patch_grid_size[0] * patch_grid_size[1]
+
+        self.in_chans = in_chans
+        self.embed_dim = embed_dim
+
+        self.flatten_embedding = flatten_embedding
+
+        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
+        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+    def forward(self, x: Tensor) -> Tensor:
+        _, _, H, W = x.shape
+        patch_H, patch_W = self.patch_size
+
+        assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
+        assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
+
+        x = self.proj(x)  # B C H W
+        H, W = x.size(2), x.size(3)
+        x = x.flatten(2).transpose(1, 2)  # B HW C
+        x = self.norm(x)
+        if not self.flatten_embedding:
+            x = x.reshape(-1, H, W, self.embed_dim)  # B H W C
+        return x
+
+    def flops(self) -> float:
+        Ho, Wo = self.patches_resolution
+        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
+        if self.norm is not None:
+            flops += Ho * Wo * self.embed_dim
+        return flops

+ 63 - 0
roma/models/transformer/layers/swiglu_ffn.py

@@ -0,0 +1,63 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Callable, Optional
+
+from torch import Tensor, nn
+import torch.nn.functional as F
+
+
+class SwiGLUFFN(nn.Module):
+    def __init__(
+        self,
+        in_features: int,
+        hidden_features: Optional[int] = None,
+        out_features: Optional[int] = None,
+        act_layer: Callable[..., nn.Module] = None,
+        drop: float = 0.0,
+        bias: bool = True,
+    ) -> None:
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
+        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
+
+    def forward(self, x: Tensor) -> Tensor:
+        x12 = self.w12(x)
+        x1, x2 = x12.chunk(2, dim=-1)
+        hidden = F.silu(x1) * x2
+        return self.w3(hidden)
+
+
+try:
+    from xformers.ops import SwiGLU
+
+    XFORMERS_AVAILABLE = True
+except ImportError:
+    SwiGLU = SwiGLUFFN
+    XFORMERS_AVAILABLE = False
+
+
+class SwiGLUFFNFused(SwiGLU):
+    def __init__(
+        self,
+        in_features: int,
+        hidden_features: Optional[int] = None,
+        out_features: Optional[int] = None,
+        act_layer: Callable[..., nn.Module] = None,
+        drop: float = 0.0,
+        bias: bool = True,
+    ) -> None:
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
+        super().__init__(
+            in_features=in_features,
+            hidden_features=hidden_features,
+            out_features=out_features,
+            bias=bias,
+        )

+ 1 - 0
roma/train/__init__.py

@@ -0,0 +1 @@
+from .train import train_k_epochs

+ 102 - 0
roma/train/train.py

@@ -0,0 +1,102 @@
+from tqdm import tqdm
+from roma.utils.utils import to_cuda
+import roma
+import torch
+import wandb
+
+def log_param_statistics(named_parameters, norm_type = 2):
+    named_parameters = list(named_parameters)
+    grads = [p.grad for n, p in named_parameters if p.grad is not None]
+    weight_norms = [p.norm(p=norm_type) for n, p in named_parameters if p.grad is not None]
+    names = [n for n,p in named_parameters if p.grad is not None]
+    param_norm = torch.stack(weight_norms).norm(p=norm_type)
+    device = grads[0].device
+    grad_norms = torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads])
+    nans_or_infs = torch.isinf(grad_norms) | torch.isnan(grad_norms)
+    nan_inf_names = [name for name, naninf in zip(names, nans_or_infs) if naninf]
+    total_grad_norm = torch.norm(grad_norms, norm_type)
+    if torch.any(nans_or_infs):
+        print(f"These params have nan or inf grads: {nan_inf_names}")
+    wandb.log({"grad_norm": total_grad_norm.item()}, step = roma.GLOBAL_STEP)
+    wandb.log({"param_norm": param_norm.item()}, step = roma.GLOBAL_STEP)
+
+def train_step(train_batch, model, objective, optimizer, grad_scaler, grad_clip_norm = 1.,**kwargs):
+    optimizer.zero_grad()
+    out = model(train_batch)
+    l = objective(out, train_batch)
+    grad_scaler.scale(l).backward()
+    grad_scaler.unscale_(optimizer)
+    log_param_statistics(model.named_parameters())
+    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_norm) # what should max norm be?
+    grad_scaler.step(optimizer)
+    grad_scaler.update()
+    wandb.log({"grad_scale": grad_scaler._scale.item()}, step = roma.GLOBAL_STEP)
+    if grad_scaler._scale < 1.:
+        grad_scaler._scale = torch.tensor(1.).to(grad_scaler._scale)
+    roma.GLOBAL_STEP = roma.GLOBAL_STEP + roma.STEP_SIZE # increment global step
+    return {"train_out": out, "train_loss": l.item()}
+
+
+def train_k_steps(
+    n_0, k, dataloader, model, objective, optimizer, lr_scheduler, grad_scaler, progress_bar=True, grad_clip_norm = 1., warmup = None, ema_model = None,
+):
+    for n in tqdm(range(n_0, n_0 + k), disable=(not progress_bar) or roma.RANK > 0):
+        batch = next(dataloader)
+        model.train(True)
+        batch = to_cuda(batch)
+        train_step(
+            train_batch=batch,
+            model=model,
+            objective=objective,
+            optimizer=optimizer,
+            lr_scheduler=lr_scheduler,
+            grad_scaler=grad_scaler,
+            n=n,
+            grad_clip_norm = grad_clip_norm,
+        )
+        if ema_model is not None:
+            ema_model.update()
+        if warmup is not None:
+            with warmup.dampening():
+                lr_scheduler.step()
+        else:
+            lr_scheduler.step()
+        [wandb.log({f"lr_group_{grp}": lr}) for grp, lr in enumerate(lr_scheduler.get_last_lr())]
+
+
+def train_epoch(
+    dataloader=None,
+    model=None,
+    objective=None,
+    optimizer=None,
+    lr_scheduler=None,
+    epoch=None,
+):
+    model.train(True)
+    print(f"At epoch {epoch}")
+    for batch in tqdm(dataloader, mininterval=5.0):
+        batch = to_cuda(batch)
+        train_step(
+            train_batch=batch, model=model, objective=objective, optimizer=optimizer
+        )
+    lr_scheduler.step()
+    return {
+        "model": model,
+        "optimizer": optimizer,
+        "lr_scheduler": lr_scheduler,
+        "epoch": epoch,
+    }
+
+
+def train_k_epochs(
+    start_epoch, end_epoch, dataloader, model, objective, optimizer, lr_scheduler
+):
+    for epoch in range(start_epoch, end_epoch + 1):
+        train_epoch(
+            dataloader=dataloader,
+            model=model,
+            objective=objective,
+            optimizer=optimizer,
+            lr_scheduler=lr_scheduler,
+            epoch=epoch,
+        )

+ 16 - 0
roma/utils/__init__.py

@@ -0,0 +1,16 @@
+from .utils import (
+    pose_auc,
+    get_pose,
+    compute_relative_pose,
+    compute_pose_error,
+    estimate_pose,
+    estimate_pose_uncalibrated,
+    rotate_intrinsic,
+    get_tuple_transform_ops,
+    get_depth_tuple_transform_ops,
+    warp_kpts,
+    numpy_to_pil,
+    tensor_to_pil,
+    recover_pose,
+    signed_left_to_right_epipolar_distance,
+)

+ 8 - 0
roma/utils/kde.py

@@ -0,0 +1,8 @@
+import torch
+
+def kde(x, std = 0.1):
+    # use a gaussian kernel to estimate density
+    x = x.half() # Do it in half precision
+    scores = (-torch.cdist(x,x)**2/(2*std**2)).exp()
+    density = scores.sum(dim=-1)
+    return density

+ 47 - 0
roma/utils/local_correlation.py

@@ -0,0 +1,47 @@
+import torch
+import torch.nn.functional as F
+
+def local_correlation(
+    feature0,
+    feature1,
+    local_radius,
+    padding_mode="zeros",
+    flow = None,
+    sample_mode = "bilinear",
+):
+    r = local_radius
+    K = (2*r+1)**2
+    B, c, h, w = feature0.size()
+    feature0 = feature0.half()
+    feature1 = feature1.half()
+    corr = torch.empty((B,K,h,w), device = feature0.device, dtype=feature0.dtype)
+    if flow is None:
+        # If flow is None, assume feature0 and feature1 are aligned
+        coords = torch.meshgrid(
+                (
+                    torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device="cuda"),
+                    torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device="cuda"),
+                ))
+        coords = torch.stack((coords[1], coords[0]), dim=-1)[
+            None
+        ].expand(B, h, w, 2)
+    else:
+        coords = flow.permute(0,2,3,1) # If using flow, sample around flow target.
+    local_window = torch.meshgrid(
+                (
+                    torch.linspace(-2*local_radius/h, 2*local_radius/h, 2*r+1, device="cuda"),
+                    torch.linspace(-2*local_radius/w, 2*local_radius/w, 2*r+1, device="cuda"),
+                ))
+    local_window = torch.stack((local_window[1], local_window[0]), dim=-1)[
+            None
+        ].expand(1, 2*r+1, 2*r+1, 2).reshape(1, (2*r+1)**2, 2)
+    for _ in range(B):
+        with torch.no_grad():
+            local_window_coords = (coords[_,:,:,None]+local_window[:,None,None]).reshape(1,h,w*(2*r+1)**2,2).half()
+            window_feature = F.grid_sample(
+                feature1[_:_+1], local_window_coords, padding_mode=padding_mode, align_corners=False, mode = sample_mode, #
+            )
+            window_feature = window_feature.reshape(c,h,w,(2*r+1)**2)
+        corr[_] = (feature0[_,...,None]/(c**.5)*window_feature).sum(dim=0).permute(2,0,1)
+    torch.cuda.empty_cache()
+    return corr

+ 118 - 0
roma/utils/transforms.py

@@ -0,0 +1,118 @@
+from typing import Dict
+import numpy as np
+import torch
+import kornia.augmentation as K
+from kornia.geometry.transform import warp_perspective
+
+# Adapted from Kornia
+class GeometricSequential:
+    def __init__(self, *transforms, align_corners=True) -> None:
+        self.transforms = transforms
+        self.align_corners = align_corners
+
+    def __call__(self, x, mode="bilinear"):
+        b, c, h, w = x.shape
+        M = torch.eye(3, device=x.device)[None].expand(b, 3, 3)
+        for t in self.transforms:
+            if np.random.rand() < t.p:
+                M = M.matmul(
+                    t.compute_transformation(x, t.generate_parameters((b, c, h, w)), None)
+                )
+        return (
+            warp_perspective(
+                x, M, dsize=(h, w), mode=mode, align_corners=self.align_corners
+            ),
+            M,
+        )
+
+    def apply_transform(self, x, M, mode="bilinear"):
+        b, c, h, w = x.shape
+        return warp_perspective(
+            x, M, dsize=(h, w), align_corners=self.align_corners, mode=mode
+        )
+
+
+class RandomPerspective(K.RandomPerspective):
+    def generate_parameters(self, batch_shape: torch.Size) -> Dict[str, torch.Tensor]:
+        distortion_scale = torch.as_tensor(
+            self.distortion_scale, device=self._device, dtype=self._dtype
+        )
+        return self.random_perspective_generator(
+            batch_shape[0],
+            batch_shape[-2],
+            batch_shape[-1],
+            distortion_scale,
+            self.same_on_batch,
+            self.device,
+            self.dtype,
+        )
+
+    def random_perspective_generator(
+        self,
+        batch_size: int,
+        height: int,
+        width: int,
+        distortion_scale: torch.Tensor,
+        same_on_batch: bool = False,
+        device: torch.device = torch.device("cpu"),
+        dtype: torch.dtype = torch.float32,
+    ) -> Dict[str, torch.Tensor]:
+        r"""Get parameters for ``perspective`` for a random perspective transform.
+
+        Args:
+            batch_size (int): the tensor batch size.
+            height (int) : height of the image.
+            width (int): width of the image.
+            distortion_scale (torch.Tensor): it controls the degree of distortion and ranges from 0 to 1.
+            same_on_batch (bool): apply the same transformation across the batch. Default: False.
+            device (torch.device): the device on which the random numbers will be generated. Default: cpu.
+            dtype (torch.dtype): the data type of the generated random numbers. Default: float32.
+
+        Returns:
+            params Dict[str, torch.Tensor]: parameters to be passed for transformation.
+                - start_points (torch.Tensor): element-wise perspective source areas with a shape of (B, 4, 2).
+                - end_points (torch.Tensor): element-wise perspective target areas with a shape of (B, 4, 2).
+
+        Note:
+            The generated random numbers are not reproducible across different devices and dtypes.
+        """
+        if not (distortion_scale.dim() == 0 and 0 <= distortion_scale <= 1):
+            raise AssertionError(
+                f"'distortion_scale' must be a scalar within [0, 1]. Got {distortion_scale}."
+            )
+        if not (
+            type(height) is int and height > 0 and type(width) is int and width > 0
+        ):
+            raise AssertionError(
+                f"'height' and 'width' must be integers. Got {height}, {width}."
+            )
+
+        start_points: torch.Tensor = torch.tensor(
+            [[[0.0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]]],
+            device=distortion_scale.device,
+            dtype=distortion_scale.dtype,
+        ).expand(batch_size, -1, -1)
+
+        # generate random offset not larger than half of the image
+        fx = distortion_scale * width / 2
+        fy = distortion_scale * height / 2
+
+        factor = torch.stack([fx, fy], dim=0).view(-1, 1, 2)
+        offset = (torch.rand_like(start_points) - 0.5) * 2
+        end_points = start_points + factor * offset
+
+        return dict(start_points=start_points, end_points=end_points)
+
+
+
+class RandomErasing:
+    def __init__(self, p = 0., scale = 0.) -> None:
+        self.p = p
+        self.scale = scale
+        self.random_eraser = K.RandomErasing(scale = (0.02, scale), p = p)
+    def __call__(self, image, depth):
+        if self.p > 0:
+            image = self.random_eraser(image)
+            depth = self.random_eraser(depth, params=self.random_eraser._params)
+        return image, depth
+        

+ 622 - 0
roma/utils/utils.py

@@ -0,0 +1,622 @@
+import warnings
+import numpy as np
+import cv2
+import math
+import torch
+from torchvision import transforms
+from torchvision.transforms.functional import InterpolationMode
+import torch.nn.functional as F
+from PIL import Image
+import kornia
+
+def recover_pose(E, kpts0, kpts1, K0, K1, mask):
+    best_num_inliers = 0
+    K0inv = np.linalg.inv(K0[:2,:2])
+    K1inv = np.linalg.inv(K1[:2,:2])
+
+    kpts0_n = (K0inv @ (kpts0-K0[None,:2,2]).T).T 
+    kpts1_n = (K1inv @ (kpts1-K1[None,:2,2]).T).T
+
+    for _E in np.split(E, len(E) / 3):
+        n, R, t, _ = cv2.recoverPose(_E, kpts0_n, kpts1_n, np.eye(3), 1e9, mask=mask)
+        if n > best_num_inliers:
+            best_num_inliers = n
+            ret = (R, t, mask.ravel() > 0)
+    return ret
+
+
+
+# Code taken from https://github.com/PruneTruong/DenseMatching/blob/40c29a6b5c35e86b9509e65ab0cd12553d998e5f/validation/utils_pose_estimation.py
+# --- GEOMETRY ---
+def estimate_pose(kpts0, kpts1, K0, K1, norm_thresh, conf=0.99999):
+    if len(kpts0) < 5:
+        return None
+    K0inv = np.linalg.inv(K0[:2,:2])
+    K1inv = np.linalg.inv(K1[:2,:2])
+
+    kpts0 = (K0inv @ (kpts0-K0[None,:2,2]).T).T 
+    kpts1 = (K1inv @ (kpts1-K1[None,:2,2]).T).T
+    E, mask = cv2.findEssentialMat(
+        kpts0, kpts1, np.eye(3), threshold=norm_thresh, prob=conf
+    )
+
+    ret = None
+    if E is not None:
+        best_num_inliers = 0
+
+        for _E in np.split(E, len(E) / 3):
+            n, R, t, _ = cv2.recoverPose(_E, kpts0, kpts1, np.eye(3), 1e9, mask=mask)
+            if n > best_num_inliers:
+                best_num_inliers = n
+                ret = (R, t, mask.ravel() > 0)
+    return ret
+
+def estimate_pose_uncalibrated(kpts0, kpts1, K0, K1, norm_thresh, conf=0.99999):
+    if len(kpts0) < 5:
+        return None
+    method = cv2.USAC_ACCURATE
+    F, mask = cv2.findFundamentalMat(
+        kpts0, kpts1, ransacReprojThreshold=norm_thresh, confidence=conf, method=method, maxIters=10000
+    )
+    E = K1.T@F@K0
+    ret = None
+    if E is not None:
+        best_num_inliers = 0
+        K0inv = np.linalg.inv(K0[:2,:2])
+        K1inv = np.linalg.inv(K1[:2,:2])
+
+        kpts0_n = (K0inv @ (kpts0-K0[None,:2,2]).T).T 
+        kpts1_n = (K1inv @ (kpts1-K1[None,:2,2]).T).T
+ 
+        for _E in np.split(E, len(E) / 3):
+            n, R, t, _ = cv2.recoverPose(_E, kpts0_n, kpts1_n, np.eye(3), 1e9, mask=mask)
+            if n > best_num_inliers:
+                best_num_inliers = n
+                ret = (R, t, mask.ravel() > 0)
+    return ret
+
+def unnormalize_coords(x_n,h,w):
+    x = torch.stack(
+        (w * (x_n[..., 0] + 1) / 2, h * (x_n[..., 1] + 1) / 2), dim=-1
+    )  # [-1+1/h, 1-1/h] -> [0.5, h-0.5]
+    return x
+
+
+def rotate_intrinsic(K, n):
+    base_rot = np.array([[0, 1, 0], [-1, 0, 0], [0, 0, 1]])
+    rot = np.linalg.matrix_power(base_rot, n)
+    return rot @ K
+
+
+def rotate_pose_inplane(i_T_w, rot):
+    rotation_matrices = [
+        np.array(
+            [
+                [np.cos(r), -np.sin(r), 0.0, 0.0],
+                [np.sin(r), np.cos(r), 0.0, 0.0],
+                [0.0, 0.0, 1.0, 0.0],
+                [0.0, 0.0, 0.0, 1.0],
+            ],
+            dtype=np.float32,
+        )
+        for r in [np.deg2rad(d) for d in (0, 270, 180, 90)]
+    ]
+    return np.dot(rotation_matrices[rot], i_T_w)
+
+
+def scale_intrinsics(K, scales):
+    scales = np.diag([1.0 / scales[0], 1.0 / scales[1], 1.0])
+    return np.dot(scales, K)
+
+
+def to_homogeneous(points):
+    return np.concatenate([points, np.ones_like(points[:, :1])], axis=-1)
+
+
+def angle_error_mat(R1, R2):
+    cos = (np.trace(np.dot(R1.T, R2)) - 1) / 2
+    cos = np.clip(cos, -1.0, 1.0)  # numercial errors can make it out of bounds
+    return np.rad2deg(np.abs(np.arccos(cos)))
+
+
+def angle_error_vec(v1, v2):
+    n = np.linalg.norm(v1) * np.linalg.norm(v2)
+    return np.rad2deg(np.arccos(np.clip(np.dot(v1, v2) / n, -1.0, 1.0)))
+
+
+def compute_pose_error(T_0to1, R, t):
+    R_gt = T_0to1[:3, :3]
+    t_gt = T_0to1[:3, 3]
+    error_t = angle_error_vec(t.squeeze(), t_gt)
+    error_t = np.minimum(error_t, 180 - error_t)  # ambiguity of E estimation
+    error_R = angle_error_mat(R, R_gt)
+    return error_t, error_R
+
+
+def pose_auc(errors, thresholds):
+    sort_idx = np.argsort(errors)
+    errors = np.array(errors.copy())[sort_idx]
+    recall = (np.arange(len(errors)) + 1) / len(errors)
+    errors = np.r_[0.0, errors]
+    recall = np.r_[0.0, recall]
+    aucs = []
+    for t in thresholds:
+        last_index = np.searchsorted(errors, t)
+        r = np.r_[recall[:last_index], recall[last_index - 1]]
+        e = np.r_[errors[:last_index], t]
+        aucs.append(np.trapz(r, x=e) / t)
+    return aucs
+
+
+# From Patch2Pix https://github.com/GrumpyZhou/patch2pix
+def get_depth_tuple_transform_ops_nearest_exact(resize=None):
+    ops = []
+    if resize:
+        ops.append(TupleResizeNearestExact(resize))
+    return TupleCompose(ops)
+
+def get_depth_tuple_transform_ops(resize=None, normalize=True, unscale=False):
+    ops = []
+    if resize:
+        ops.append(TupleResize(resize, mode=InterpolationMode.BILINEAR))
+    return TupleCompose(ops)
+
+
+def get_tuple_transform_ops(resize=None, normalize=True, unscale=False, clahe = False, colorjiggle_params = None):
+    ops = []
+    if resize:
+        ops.append(TupleResize(resize))
+    ops.append(TupleToTensorScaled())
+    if normalize:
+        ops.append(
+            TupleNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+        )  # Imagenet mean/std
+    return TupleCompose(ops)
+
+class ToTensorScaled(object):
+    """Convert a RGB PIL Image to a CHW ordered Tensor, scale the range to [0, 1]"""
+
+    def __call__(self, im):
+        if not isinstance(im, torch.Tensor):
+            im = np.array(im, dtype=np.float32).transpose((2, 0, 1))
+            im /= 255.0
+            return torch.from_numpy(im)
+        else:
+            return im
+
+    def __repr__(self):
+        return "ToTensorScaled(./255)"
+
+
+class TupleToTensorScaled(object):
+    def __init__(self):
+        self.to_tensor = ToTensorScaled()
+
+    def __call__(self, im_tuple):
+        return [self.to_tensor(im) for im in im_tuple]
+
+    def __repr__(self):
+        return "TupleToTensorScaled(./255)"
+
+
+class ToTensorUnscaled(object):
+    """Convert a RGB PIL Image to a CHW ordered Tensor"""
+
+    def __call__(self, im):
+        return torch.from_numpy(np.array(im, dtype=np.float32).transpose((2, 0, 1)))
+
+    def __repr__(self):
+        return "ToTensorUnscaled()"
+
+
+class TupleToTensorUnscaled(object):
+    """Convert a RGB PIL Image to a CHW ordered Tensor"""
+
+    def __init__(self):
+        self.to_tensor = ToTensorUnscaled()
+
+    def __call__(self, im_tuple):
+        return [self.to_tensor(im) for im in im_tuple]
+
+    def __repr__(self):
+        return "TupleToTensorUnscaled()"
+
+class TupleResizeNearestExact:
+    def __init__(self, size):
+        self.size = size
+    def __call__(self, im_tuple):
+        return [F.interpolate(im, size = self.size, mode = 'nearest-exact') for im in im_tuple]
+
+    def __repr__(self):
+        return "TupleResizeNearestExact(size={})".format(self.size)
+
+
+class TupleResize(object):
+    def __init__(self, size, mode=InterpolationMode.BICUBIC):
+        self.size = size
+        self.resize = transforms.Resize(size, mode)
+    def __call__(self, im_tuple):
+        return [self.resize(im) for im in im_tuple]
+
+    def __repr__(self):
+        return "TupleResize(size={})".format(self.size)
+    
+class Normalize:
+    def __call__(self,im):
+        mean = im.mean(dim=(1,2), keepdims=True)
+        std = im.std(dim=(1,2), keepdims=True)
+        return (im-mean)/std
+
+
+class TupleNormalize(object):
+    def __init__(self, mean, std):
+        self.mean = mean
+        self.std = std
+        self.normalize = transforms.Normalize(mean=mean, std=std)
+
+    def __call__(self, im_tuple):
+        c,h,w = im_tuple[0].shape
+        if c > 3:
+            warnings.warn(f"Number of channels c={c} > 3, assuming first 3 are rgb")
+        return [self.normalize(im[:3]) for im in im_tuple]
+
+    def __repr__(self):
+        return "TupleNormalize(mean={}, std={})".format(self.mean, self.std)
+
+
+class TupleCompose(object):
+    def __init__(self, transforms):
+        self.transforms = transforms
+
+    def __call__(self, im_tuple):
+        for t in self.transforms:
+            im_tuple = t(im_tuple)
+        return im_tuple
+
+    def __repr__(self):
+        format_string = self.__class__.__name__ + "("
+        for t in self.transforms:
+            format_string += "\n"
+            format_string += "    {0}".format(t)
+        format_string += "\n)"
+        return format_string
+
+@torch.no_grad()
+def cls_to_flow(cls, deterministic_sampling = True):
+    B,C,H,W = cls.shape
+    device = cls.device
+    res = round(math.sqrt(C))
+    G = torch.meshgrid(*[torch.linspace(-1+1/res, 1-1/res, steps = res, device = device) for _ in range(2)])
+    G = torch.stack([G[1],G[0]],dim=-1).reshape(C,2)
+    if deterministic_sampling:
+        sampled_cls = cls.max(dim=1).indices
+    else:
+        sampled_cls = torch.multinomial(cls.permute(0,2,3,1).reshape(B*H*W,C).softmax(dim=-1), 1).reshape(B,H,W)
+    flow = G[sampled_cls]
+    return flow
+
+@torch.no_grad()
+def cls_to_flow_refine(cls):
+    B,C,H,W = cls.shape
+    device = cls.device
+    res = round(math.sqrt(C))
+    G = torch.meshgrid(*[torch.linspace(-1+1/res, 1-1/res, steps = res, device = device) for _ in range(2)])
+    G = torch.stack([G[1],G[0]],dim=-1).reshape(C,2)
+    cls = cls.softmax(dim=1)
+    mode = cls.max(dim=1).indices
+    
+    index = torch.stack((mode-1, mode, mode+1, mode - res, mode + res), dim = 1).clamp(0,C - 1).long()
+    neighbours = torch.gather(cls, dim = 1, index = index)[...,None]
+    flow = neighbours[:,0] * G[index[:,0]] + neighbours[:,1] * G[index[:,1]] + neighbours[:,2] * G[index[:,2]] + neighbours[:,3] * G[index[:,3]] + neighbours[:,4] * G[index[:,4]]
+    tot_prob = neighbours.sum(dim=1)  
+    flow = flow / tot_prob
+    return flow
+
+
+def get_gt_warp(depth1, depth2, T_1to2, K1, K2, depth_interpolation_mode = 'bilinear', relative_depth_error_threshold = 0.05, H = None, W = None):
+    
+    if H is None:
+        B,H,W = depth1.shape
+    else:
+        B = depth1.shape[0]
+    with torch.no_grad():
+        x1_n = torch.meshgrid(
+            *[
+                torch.linspace(
+                    -1 + 1 / n, 1 - 1 / n, n, device=depth1.device
+                )
+                for n in (B, H, W)
+            ]
+        )
+        x1_n = torch.stack((x1_n[2], x1_n[1]), dim=-1).reshape(B, H * W, 2)
+        mask, x2 = warp_kpts(
+            x1_n.double(),
+            depth1.double(),
+            depth2.double(),
+            T_1to2.double(),
+            K1.double(),
+            K2.double(),
+            depth_interpolation_mode = depth_interpolation_mode,
+            relative_depth_error_threshold = relative_depth_error_threshold,
+        )
+        prob = mask.float().reshape(B, H, W)
+        x2 = x2.reshape(B, H, W, 2)
+        return x2, prob
+
+@torch.no_grad()
+def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1, smooth_mask = False, return_relative_depth_error = False, depth_interpolation_mode = "bilinear", relative_depth_error_threshold = 0.05):
+    """Warp kpts0 from I0 to I1 with depth, K and Rt
+    Also check covisibility and depth consistency.
+    Depth is consistent if relative error < 0.2 (hard-coded).
+    # https://github.com/zju3dv/LoFTR/blob/94e98b695be18acb43d5d3250f52226a8e36f839/src/loftr/utils/geometry.py adapted from here
+    Args:
+        kpts0 (torch.Tensor): [N, L, 2] - <x, y>, should be normalized in (-1,1)
+        depth0 (torch.Tensor): [N, H, W],
+        depth1 (torch.Tensor): [N, H, W],
+        T_0to1 (torch.Tensor): [N, 3, 4],
+        K0 (torch.Tensor): [N, 3, 3],
+        K1 (torch.Tensor): [N, 3, 3],
+    Returns:
+        calculable_mask (torch.Tensor): [N, L]
+        warped_keypoints0 (torch.Tensor): [N, L, 2] <x0_hat, y1_hat>
+    """
+    (
+        n,
+        h,
+        w,
+    ) = depth0.shape
+    if depth_interpolation_mode == "combined":
+        # Inspired by approach in inloc, try to fill holes from bilinear interpolation by nearest neighbour interpolation
+        if smooth_mask:
+            raise NotImplementedError("Combined bilinear and NN warp not implemented")
+        valid_bilinear, warp_bilinear = warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1, 
+                  smooth_mask = smooth_mask, 
+                  return_relative_depth_error = return_relative_depth_error, 
+                  depth_interpolation_mode = "bilinear",
+                  relative_depth_error_threshold = relative_depth_error_threshold)
+        valid_nearest, warp_nearest = warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1, 
+                  smooth_mask = smooth_mask, 
+                  return_relative_depth_error = return_relative_depth_error, 
+                  depth_interpolation_mode = "nearest-exact",
+                  relative_depth_error_threshold = relative_depth_error_threshold)
+        nearest_valid_bilinear_invalid = (~valid_bilinear).logical_and(valid_nearest) 
+        warp = warp_bilinear.clone()
+        warp[nearest_valid_bilinear_invalid] = warp_nearest[nearest_valid_bilinear_invalid]
+        valid = valid_bilinear | valid_nearest
+        return valid, warp
+        
+        
+    kpts0_depth = F.grid_sample(depth0[:, None], kpts0[:, :, None], mode = depth_interpolation_mode, align_corners=False)[
+        :, 0, :, 0
+    ]
+    kpts0 = torch.stack(
+        (w * (kpts0[..., 0] + 1) / 2, h * (kpts0[..., 1] + 1) / 2), dim=-1
+    )  # [-1+1/h, 1-1/h] -> [0.5, h-0.5]
+    # Sample depth, get calculable_mask on depth != 0
+    nonzero_mask = kpts0_depth != 0
+
+    # Unproject
+    kpts0_h = (
+        torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1)
+        * kpts0_depth[..., None]
+    )  # (N, L, 3)
+    kpts0_n = K0.inverse() @ kpts0_h.transpose(2, 1)  # (N, 3, L)
+    kpts0_cam = kpts0_n
+
+    # Rigid Transform
+    w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]]  # (N, 3, L)
+    w_kpts0_depth_computed = w_kpts0_cam[:, 2, :]
+
+    # Project
+    w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1)  # (N, L, 3)
+    w_kpts0 = w_kpts0_h[:, :, :2] / (
+        w_kpts0_h[:, :, [2]] + 1e-4
+    )  # (N, L, 2), +1e-4 to avoid zero depth
+
+    # Covisible Check
+    h, w = depth1.shape[1:3]
+    covisible_mask = (
+        (w_kpts0[:, :, 0] > 0)
+        * (w_kpts0[:, :, 0] < w - 1)
+        * (w_kpts0[:, :, 1] > 0)
+        * (w_kpts0[:, :, 1] < h - 1)
+    )
+    w_kpts0 = torch.stack(
+        (2 * w_kpts0[..., 0] / w - 1, 2 * w_kpts0[..., 1] / h - 1), dim=-1
+    )  # from [0.5,h-0.5] -> [-1+1/h, 1-1/h]
+    # w_kpts0[~covisible_mask, :] = -5 # xd
+
+    w_kpts0_depth = F.grid_sample(
+        depth1[:, None], w_kpts0[:, :, None], mode=depth_interpolation_mode, align_corners=False
+    )[:, 0, :, 0]
+    
+    relative_depth_error = (
+        (w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth
+    ).abs()
+    if not smooth_mask:
+        consistent_mask = relative_depth_error < relative_depth_error_threshold
+    else:
+        consistent_mask = (-relative_depth_error/smooth_mask).exp()
+    valid_mask = nonzero_mask * covisible_mask * consistent_mask
+    if return_relative_depth_error:
+        return relative_depth_error, w_kpts0
+    else:
+        return valid_mask, w_kpts0
+
+imagenet_mean = torch.tensor([0.485, 0.456, 0.406])
+imagenet_std = torch.tensor([0.229, 0.224, 0.225])
+
+
+def numpy_to_pil(x: np.ndarray):
+    """
+    Args:
+        x: Assumed to be of shape (h,w,c)
+    """
+    if isinstance(x, torch.Tensor):
+        x = x.detach().cpu().numpy()
+    if x.max() <= 1.01:
+        x *= 255
+    x = x.astype(np.uint8)
+    return Image.fromarray(x)
+
+
+def tensor_to_pil(x, unnormalize=False):
+    if unnormalize:
+        x = x * (imagenet_std[:, None, None].to(x.device)) + (imagenet_mean[:, None, None].to(x.device))
+    x = x.detach().permute(1, 2, 0).cpu().numpy()
+    x = np.clip(x, 0.0, 1.0)
+    return numpy_to_pil(x)
+
+
+def to_cuda(batch):
+    for key, value in batch.items():
+        if isinstance(value, torch.Tensor):
+            batch[key] = value.cuda()
+    return batch
+
+
+def to_cpu(batch):
+    for key, value in batch.items():
+        if isinstance(value, torch.Tensor):
+            batch[key] = value.cpu()
+    return batch
+
+
+def get_pose(calib):
+    w, h = np.array(calib["imsize"])[0]
+    return np.array(calib["K"]), np.array(calib["R"]), np.array(calib["T"]).T, h, w
+
+
+def compute_relative_pose(R1, t1, R2, t2):
+    rots = R2 @ (R1.T)
+    trans = -rots @ t1 + t2
+    return rots, trans
+
+@torch.no_grad()
+def reset_opt(opt):
+    for group in opt.param_groups:
+        for p in group['params']:
+            if p.requires_grad:
+                state = opt.state[p]
+                # State initialization
+
+                # Exponential moving average of gradient values
+                state['exp_avg'] = torch.zeros_like(p)
+                # Exponential moving average of squared gradient values
+                state['exp_avg_sq'] = torch.zeros_like(p)
+                # Exponential moving average of gradient difference
+                state['exp_avg_diff'] = torch.zeros_like(p)
+
+
+def flow_to_pixel_coords(flow, h1, w1):
+    flow = (
+        torch.stack(
+            (
+                w1 * (flow[..., 0] + 1) / 2,
+                h1 * (flow[..., 1] + 1) / 2,
+            ),
+            axis=-1,
+        )
+    )
+    return flow
+
+def flow_to_normalized_coords(flow, h1, w1):
+    flow = (
+        torch.stack(
+            (
+                2 * (flow[..., 0]) / w1 - 1,
+                2 * (flow[..., 1]) / h1 - 1,
+            ),
+            axis=-1,
+        )
+    )
+    return flow
+
+
+def warp_to_pixel_coords(warp, h1, w1, h2, w2):
+    warp1 = warp[..., :2]
+    warp1 = (
+        torch.stack(
+            (
+                w1 * (warp1[..., 0] + 1) / 2,
+                h1 * (warp1[..., 1] + 1) / 2,
+            ),
+            axis=-1,
+        )
+    )
+    warp2 = warp[..., 2:]
+    warp2 = (
+        torch.stack(
+            (
+                w2 * (warp2[..., 0] + 1) / 2,
+                h2 * (warp2[..., 1] + 1) / 2,
+            ),
+            axis=-1,
+        )
+    )
+    return torch.cat((warp1,warp2), dim=-1)
+
+
+
+def signed_point_line_distance(point, line, eps: float = 1e-9):
+    r"""Return the distance from points to lines.
+
+    Args:
+       point: (possibly homogeneous) points :math:`(*, N, 2 or 3)`.
+       line: lines coefficients :math:`(a, b, c)` with shape :math:`(*, N, 3)`, where :math:`ax + by + c = 0`.
+       eps: Small constant for safe sqrt.
+
+    Returns:
+        the computed distance with shape :math:`(*, N)`.
+    """
+
+    if not point.shape[-1] in (2, 3):
+        raise ValueError(f"pts must be a (*, 2 or 3) tensor. Got {point.shape}")
+
+    if not line.shape[-1] == 3:
+        raise ValueError(f"lines must be a (*, 3) tensor. Got {line.shape}")
+
+    numerator = (line[..., 0] * point[..., 0] + line[..., 1] * point[..., 1] + line[..., 2])
+    denominator = line[..., :2].norm(dim=-1)
+
+    return numerator / (denominator + eps)
+
+
+def signed_left_to_right_epipolar_distance(pts1, pts2, Fm):
+    r"""Return one-sided epipolar distance for correspondences given the fundamental matrix.
+
+    This method measures the distance from points in the right images to the epilines
+    of the corresponding points in the left images as they reflect in the right images.
+
+    Args:
+       pts1: correspondences from the left images with shape
+         :math:`(*, N, 2 or 3)`. If they are not homogeneous, converted automatically.
+       pts2: correspondences from the right images with shape
+         :math:`(*, N, 2 or 3)`. If they are not homogeneous, converted automatically.
+       Fm: Fundamental matrices with shape :math:`(*, 3, 3)`. Called Fm to
+         avoid ambiguity with torch.nn.functional.
+
+    Returns:
+        the computed Symmetrical distance with shape :math:`(*, N)`.
+    """
+    import kornia
+    if (len(Fm.shape) < 3) or not Fm.shape[-2:] == (3, 3):
+        raise ValueError(f"Fm must be a (*, 3, 3) tensor. Got {Fm.shape}")
+
+    if pts1.shape[-1] == 2:
+        pts1 = kornia.geometry.convert_points_to_homogeneous(pts1)
+
+    F_t = Fm.transpose(dim0=-2, dim1=-1)
+    line1_in_2 = pts1 @ F_t
+
+    return signed_point_line_distance(pts2, line1_in_2)
+
+def get_grid(b, h, w, device):
+    grid = torch.meshgrid(
+        *[
+            torch.linspace(-1 + 1 / n, 1 - 1 / n, n, device=device)
+            for n in (b, h, w)
+        ]
+    )
+    grid = torch.stack((grid[2], grid[1]), dim=-1).reshape(b, h, w, 2)
+    return grid

+ 9 - 0
setup.py

@@ -0,0 +1,9 @@
+from setuptools import setup
+
+setup(
+    name="roma",
+    packages=["roma"],
+    version="0.0.1",
+    author="Johan Edstedt",
+    install_requires=open("requirements.txt", "r").read().split("\n"),
+)