from typing import Union
from pathlib import Path
import os
import shutil

import torch

from .roma_models import roma_model, roma_model_pad, tiny_roma_v1_model

weight_urls = {
    "romatch": {
        "outdoor": "https://github.com/Parskatt/storage/releases/download/roma/roma_outdoor.pth",
        "indoor": "https://github.com/Parskatt/storage/releases/download/roma/roma_indoor.pth",
    },
    "tiny_roma_v1": {
        "outdoor": "https://github.com/Parskatt/storage/releases/download/roma/tiny_roma_v1_outdoor.pth",
    },
    "dinov2": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth",  # hopefully this doesn't change :D
}

# Directory name under "<torch home>/hub" that holds the cached XFeat repository
# code; shared by the cache-staging helper and by tiny_roma_v1_outdoor.
_XFEAT_HUB_REPO_DIRNAME = "verlab_accelerated_features_main"

# Weight files a user may drop next to the project for offline use.
_LOCAL_WEIGHT_FILENAMES = ("tiny_roma_v1_outdoor.pth", "xfeat.pt")

_MATMUL_PRECISION_ERROR = (
    "Float32 matmul precision must be set to highest for RoMa. "
    "See also https://github.com/Parskatt/RoMaV2/issues/35"
)


def _project_root() -> Path:
    """Return the directory three levels above this file's directory."""
    # romatch/models/model_zoo/__init__.py -> repo root is 3 levels up
    return Path(__file__).resolve().parents[3]


def _copy_atomically(src: Path, dst: Path) -> None:
    """Copy the file or directory *src* to *dst* without exposing a partial *dst*.

    The data is first copied to a temporary sibling of *dst* and then renamed
    into place, so an interrupted copy never leaves something at *dst* that a
    later ``dst.exists()`` check would mistake for a complete copy.

    Raises:
        OSError: If copying or renaming fails. The temporary copy is removed
            (best-effort) before the error propagates.
    """
    staging = dst.with_name(f"{dst.name}.partial-{os.getpid()}")
    try:
        if src.is_dir():
            shutil.copytree(str(src), str(staging))
        else:
            shutil.copyfile(str(src), str(staging))
        os.replace(str(staging), str(dst))
    except BaseException:
        # Clean up whatever was written, then let the original error surface.
        if staging.is_dir():
            shutil.rmtree(str(staging), ignore_errors=True)
        else:
            try:
                staging.unlink()
            except OSError:
                pass
        raise


def _ensure_local_torch_home_with_optional_weights() -> Path:
    """
    Prepare the torch-hub cache directory and pre-populate it for offline use.

    The cache directory is chosen as follows:

    1. ``TORCH_HOME``, if it is already set (non-empty). It is left untouched,
       and files are staged there because that is the directory torch.hub
       will actually read.
    2. Otherwise ``ROMATCH_TORCH_HOME``, if set (non-empty).
    3. Otherwise ``<project root>/torch_home``.

    In cases 2 and 3 ``TORCH_HOME`` is set to the chosen directory so that
    subsequent torch.hub calls use it.

    If the user places weight files directly under the project root (e.g. for
    offline use), they are copied into ``<torch home>/hub/checkpoints`` so that
    subsequent torch.hub calls won't need to download the weights. A cached
    copy of the XFeat repository code and the hub ``trusted_list`` found under
    ``~/.cache/torch/hub`` are copied over as well.

    Returns:
        The torch home directory that ``TORCH_HOME`` refers to after this call.

    Raises:
        OSError: If the cache directory cannot be created, or copying a weight
            file or the cached repository fails.
    """
    preset_home = os.environ.get("TORCH_HOME")
    if preset_home:
        # torch expands "~" in TORCH_HOME itself; mirror that here.
        torch_home = Path(preset_home).expanduser()
    else:
        romatch_home = os.environ.get("ROMATCH_TORCH_HOME")
        torch_home = Path(romatch_home) if romatch_home else _project_root() / "torch_home"

    hub_dir = torch_home / "hub"
    checkpoints_dir = hub_dir / "checkpoints"
    checkpoints_dir.mkdir(parents=True, exist_ok=True)

    # If present, copy user-provided weight files next to the project to where
    # torch.hub expects them.
    project_root = _project_root()
    for filename in _LOCAL_WEIGHT_FILENAMES:
        src = project_root / filename
        dst = checkpoints_dir / filename
        if not dst.exists() and src.exists():
            _copy_atomically(src, dst)

    # Also copy cached torch-hub repository code into the chosen torch home.
    #
    # Tiny RoMa's XFeat is loaded via torch.hub.load("verlab/accelerated_features", ...).
    # torch.hub may still hit the network unless the repo code directory is present.
    default_hub_dir = Path.home() / ".cache" / "torch" / "hub"
    repo_src = default_hub_dir / _XFEAT_HUB_REPO_DIRNAME
    repo_dst = hub_dir / _XFEAT_HUB_REPO_DIRNAME
    if not repo_dst.exists() and repo_src.exists():
        _copy_atomically(repo_src, repo_dst)

    trusted_src = default_hub_dir / "trusted_list"
    trusted_dst = hub_dir / "trusted_list"
    if not trusted_dst.exists() and trusted_src.exists():
        try:
            _copy_atomically(trusted_src, trusted_dst)
        except OSError:
            # Not critical; best-effort.
            pass

    # Redirect torch-hub cache to this folder unless the user already chose one.
    if not preset_home:
        os.environ["TORCH_HOME"] = str(torch_home)
    return torch_home


def tiny_roma_v1_outdoor(device, weights=None, xfeat=None):
    """Build the Tiny RoMa v1 model with outdoor weights and move it to *device*.

    Args:
        device: Target device; used as ``map_location`` when loading weights
            and passed to ``.to(...)`` on the built model.
        weights: State dict for the model. When ``None`` it is fetched with
            ``torch.hub.load_state_dict_from_url`` from
            ``weight_urls["tiny_roma_v1"]["outdoor"]``.
        xfeat: XFeat network. When ``None`` it is loaded through
            ``torch.hub.load`` and the ``.net`` attribute of the result is used.

    Returns:
        The result of ``tiny_roma_v1_model(weights=..., xfeat=...).to(device)``.
    """
    torch_home = _ensure_local_torch_home_with_optional_weights()
    if weights is None:
        weights = torch.hub.load_state_dict_from_url(
            weight_urls["tiny_roma_v1"]["outdoor"], map_location=device
        )
    if xfeat is None:
        # Important for offline use:
        # torch.hub.load("verlab/accelerated_features", ...) may still query GitHub
        # even if repository code is cached. If we have a local cached repo directory,
        # load from that directory directly.
        local_repo_dir = torch_home / "hub" / _XFEAT_HUB_REPO_DIRNAME
        if local_repo_dir.exists():
            repo_or_dir = str(local_repo_dir)
            source_kwargs = {"source": "local"}
        else:
            repo_or_dir = "verlab/accelerated_features"
            source_kwargs = {}
        xfeat = torch.hub.load(
            repo_or_dir,
            "XFeat",
            pretrained=True,
            top_k=4096,
            **source_kwargs,
        ).net

    return tiny_roma_v1_model(weights=weights, xfeat=xfeat).to(device)


def _load_roma(
    variant: str,
    device,
    weights,
    dinov2_weights,
    coarse_res,
    upsample_res,
    amp_dtype,
    symmetric,
    use_custom_corr,
    upsample_preds,
    with_padding,
    do_compile,
):
    """Shared implementation of ``roma_outdoor`` and ``roma_indoor``.

    *variant* is the key into ``weight_urls["romatch"]`` ("outdoor" or
    "indoor") used when *weights* is ``None``. All other arguments have the
    meaning documented on the public entry points.

    Raises:
        RuntimeError: If ``torch.get_float32_matmul_precision()`` is not
            ``"highest"``.
    """
    if torch.get_float32_matmul_precision() != "highest":
        raise RuntimeError(_MATMUL_PRECISION_ERROR)
    if weights is None:
        weights = torch.hub.load_state_dict_from_url(
            weight_urls["romatch"][variant], map_location=device
        )
    if dinov2_weights is None:
        dinov2_weights = torch.hub.load_state_dict_from_url(
            weight_urls["dinov2"], map_location=device
        )
    model_init = roma_model_pad if with_padding else roma_model
    model = model_init(
        resolution=coarse_res,
        upsample_preds=upsample_preds,
        weights=weights,
        dinov2_weights=dinov2_weights,
        device=device,
        amp_dtype=amp_dtype,
        symmetric=symmetric,
        use_custom_corr=use_custom_corr,
        upsample_res=upsample_res,
    )
    if do_compile:
        model.compile()
    return model


def roma_outdoor(
    device,
    weights=None,
    dinov2_weights=None,
    coarse_res: Union[int, tuple[int, int]] = 560,
    upsample_res: Union[int, tuple[int, int]] = 864,
    amp_dtype: torch.dtype = torch.float16,
    symmetric=True,
    use_custom_corr=True,
    upsample_preds=True,
    with_padding=False,
    do_compile=False,
):
    """Build the RoMa model with outdoor weights.

    Args:
        device: Used as ``map_location`` when loading weights and forwarded to
            the model constructor.
        weights: Model state dict; downloaded from
            ``weight_urls["romatch"]["outdoor"]`` when ``None``.
        dinov2_weights: DINOv2 state dict; downloaded from
            ``weight_urls["dinov2"]`` when ``None``.
        coarse_res: Forwarded to the model constructor as ``resolution``.
        upsample_res: Forwarded to the model constructor.
        amp_dtype: Forwarded to the model constructor.
        symmetric: Forwarded to the model constructor.
        use_custom_corr: Forwarded to the model constructor.
        upsample_preds: Forwarded to the model constructor.
        with_padding: Build with ``roma_model_pad`` instead of ``roma_model``.
        do_compile: Call ``model.compile()`` before returning.

    Returns:
        The constructed model.

    Raises:
        RuntimeError: If float32 matmul precision is not set to ``"highest"``.
    """
    return _load_roma(
        "outdoor",
        device=device,
        weights=weights,
        dinov2_weights=dinov2_weights,
        coarse_res=coarse_res,
        upsample_res=upsample_res,
        amp_dtype=amp_dtype,
        symmetric=symmetric,
        use_custom_corr=use_custom_corr,
        upsample_preds=upsample_preds,
        with_padding=with_padding,
        do_compile=do_compile,
    )


def roma_indoor(
    device,
    weights=None,
    dinov2_weights=None,
    coarse_res: Union[int, tuple[int, int]] = 560,
    upsample_res: Union[int, tuple[int, int]] = 864,
    amp_dtype: torch.dtype = torch.float16,
    symmetric=True,
    use_custom_corr=True,
    upsample_preds=True,
    with_padding=False,
    do_compile=False,
):
    """Build the RoMa model with indoor weights.

    Args:
        device: Used as ``map_location`` when loading weights and forwarded to
            the model constructor.
        weights: Model state dict; downloaded from
            ``weight_urls["romatch"]["indoor"]`` when ``None``.
        dinov2_weights: DINOv2 state dict; downloaded from
            ``weight_urls["dinov2"]`` when ``None``.
        coarse_res: Forwarded to the model constructor as ``resolution``.
        upsample_res: Forwarded to the model constructor.
        amp_dtype: Forwarded to the model constructor.
        symmetric: Forwarded to the model constructor.
        use_custom_corr: Forwarded to the model constructor.
        upsample_preds: Forwarded to the model constructor.
        with_padding: Build with ``roma_model_pad`` instead of ``roma_model``.
        do_compile: Call ``model.compile()`` before returning.

    Returns:
        The constructed model.

    Raises:
        RuntimeError: If float32 matmul precision is not set to ``"highest"``.
    """
    return _load_roma(
        "indoor",
        device=device,
        weights=weights,
        dinov2_weights=dinov2_weights,
        coarse_res=coarse_res,
        upsample_res=upsample_res,
        amp_dtype=amp_dtype,
        symmetric=symmetric,
        use_custom_corr=use_custom_corr,
        upsample_preds=upsample_preds,
        with_padding=with_padding,
        do_compile=do_compile,
    )